1- #![ cfg_attr( feature = "tco" , allow( internal_features) ) ]
2- #![ cfg_attr( feature = "tco" , allow( incomplete_features) ) ]
3- #![ cfg_attr( feature = "tco" , feature( explicit_tail_calls) ) ]
4- #![ cfg_attr( feature = "tco" , feature( core_intrinsics) ) ]
5-
6- use eyre:: Result ;
71use itertools:: Itertools ;
8- use openvm_circuit:: arch:: execution_mode:: Segment ;
9- use openvm_circuit:: arch:: { PreflightExecutionOutput , VirtualMachine , VmCircuitConfig , VmInstance } ;
2+ use openvm_circuit:: arch:: VmCircuitConfig ;
103use openvm_instructions:: exe:: VmExe ;
11- use openvm_sdk:: prover:: vm:: new_local_prover;
12- use openvm_sdk:: {
13- config:: { AppConfig , DEFAULT_APP_LOG_BLOWUP } ,
14- StdIn ,
15- } ;
4+ use openvm_sdk:: StdIn ;
165use openvm_stark_backend:: p3_matrix:: dense:: DenseMatrix ;
17- use openvm_stark_sdk:: config:: baby_bear_poseidon2:: BabyBearPermutationEngine ;
18- use openvm_stark_sdk:: config:: FriParameters ;
196use openvm_stark_sdk:: openvm_stark_backend:: p3_field:: PrimeField32 ;
207use openvm_stark_sdk:: p3_baby_bear:: BabyBear ;
218use powdr_autoprecompiles:: blocks:: BasicBlock ;
@@ -24,136 +11,36 @@ use std::collections::hash_map::Entry;
2411use std:: collections:: BTreeMap ;
2512use std:: { collections:: HashMap , sync:: Arc } ;
2613
27- #[ cfg( not( feature = "cuda" ) ) ]
28- use crate :: PowdrSdkCpu ;
29- use crate :: { Instr , SpecializedConfig , SpecializedConfigCpuBuilder } ;
30- use tracing:: info_span;
14+ use crate :: trace_generation:: do_with_trace;
15+ use crate :: { Instr , SpecializedConfig } ;
3116
3217use std:: collections:: HashSet ;
3318use std:: hash:: Hash ;
3419
/// Computes the common refinement of several partitions.
///
/// Each input partition is a list of equivalence classes over some subset of
/// the `Id` universe. Two ids end up in the same output class exactly when
/// every input partition puts them in the same class. An id that is absent
/// from at least one partition is its own singleton class and is omitted
/// from the output.
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
where
    Id: Eq + Hash + Copy,
{
    if partitions.is_empty() {
        return Vec::new();
    }

    // For each partition, map every id to the index of the class containing it.
    let class_index_maps: Vec<HashMap<Id, usize>> = partitions
        .iter()
        .map(|partition| {
            partition
                .iter()
                .enumerate()
                .flat_map(|(class_idx, class)| class.iter().map(move |&id| (id, class_idx)))
                .collect()
        })
        .collect();

    // The set of all ids mentioned in any partition.
    let universe: HashSet<Id> = partitions
        .iter()
        .flat_map(|partition| partition.iter().flatten())
        .copied()
        .collect();

    // Group ids by their vector of class indices ("signature") across all
    // partitions. Collecting the per-partition lookups into `Option<Vec<_>>`
    // yields `None` as soon as one partition does not contain the id; such
    // ids are their own equivalence class and are dropped from the output.
    let mut grouped: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();
    for &id in &universe {
        let signature: Option<Vec<usize>> = class_index_maps
            .iter()
            .map(|m| m.get(&id).copied())
            .collect();
        if let Some(signature) = signature {
            grouped.entry(signature).or_default().push(id);
        }
    }

    // The resulting equivalence classes are the grouped values.
    grouped.into_values().collect()
}
/// Trace data gathered from a VM execution, kept in several indexes that the
/// statistics-generation stage consumes.
#[derive(Default)]
struct Trace {
    /// Mapping (segment_idx, timestamp) -> Vec<u32>
    rows_by_time: BTreeMap<(usize, u32), Vec<u32>>,
    /// Per program counter, the observed values of each trace column
    /// (outer index: column, inner: one entry per occurrence of the PC).
    trace_values_by_pc: HashMap<u32, Vec<Vec<u32>>>,
    /// The AIR id that handled each program counter.
    air_id_by_pc: HashMap<u32, usize>,
    /// Human-readable column names, keyed by AIR id.
    column_names_by_air_id: HashMap<usize, Vec<String>>,
}
9028
9129pub fn execution_stats (
9230 exe : Arc < VmExe < BabyBear > > ,
9331 vm_config : SpecializedConfig ,
9432 inputs : StdIn ,
9533 blocks : & [ BasicBlock < Instr < BabyBear > > ] ,
96- ) -> Result < ExecutionStats , Box < dyn std:: error:: Error > > {
97- // Set app configuration
98- let app_fri_params =
99- FriParameters :: standard_with_100_bits_conjectured_security ( DEFAULT_APP_LOG_BLOWUP ) ;
100- let app_config = AppConfig :: new ( app_fri_params, vm_config. clone ( ) ) ;
101-
102- // Create the SDK
103- #[ cfg( feature = "cuda" ) ]
104- let sdk = PowdrSdkGpu :: new ( app_config) . unwrap ( ) ;
105- #[ cfg( not( feature = "cuda" ) ) ]
106- let sdk = PowdrSdkCpu :: new ( app_config) . unwrap ( ) ;
107- // Build owned vm instance, so we can mutate it later
108- let vm_builder = sdk. app_vm_builder ( ) . clone ( ) ;
109- let vm_pk = sdk. app_pk ( ) . app_vm_pk . clone ( ) ;
110- let exe = sdk. convert_to_exe ( exe. clone ( ) ) ?;
111- let mut vm_instance: VmInstance < _ , _ > = new_local_prover ( vm_builder, & vm_pk, exe. clone ( ) ) ?;
112-
113- vm_instance. reset_state ( inputs. clone ( ) ) ;
114- let metered_ctx = vm_instance. vm . build_metered_ctx ( & exe) ;
115- let metered_interpreter = vm_instance. vm . metered_interpreter ( vm_instance. exe ( ) ) ?;
116- let ( segments, _) = metered_interpreter. execute_metered ( inputs. clone ( ) , metered_ctx) ?;
117- let mut state = vm_instance. state_mut ( ) . take ( ) ;
118-
119- // Get reusable inputs for `debug_proving_ctx`, the mock prover API from OVM.
120- let vm: & mut VirtualMachine < BabyBearPermutationEngine < _ > , SpecializedConfigCpuBuilder > =
121- & mut vm_instance. vm ;
122-
123- // Mapping (segment_idx, timestamp) -> Vec<u32>
124- let mut rows_by_time = BTreeMap :: new ( ) ;
125-
126- let mut trace_values_by_pc = HashMap :: new ( ) ;
127- let mut column_names_by_air_id = HashMap :: new ( ) ;
128- let mut air_id_by_pc = HashMap :: new ( ) ;
129-
130- for ( seg_idx, segment) in segments. into_iter ( ) . enumerate ( ) {
131- let _segment_span = info_span ! ( "prove_segment" , segment = seg_idx) . entered ( ) ;
132- // We need a separate span so the metric label includes "segment" from _segment_span
133- let _prove_span = info_span ! ( "total_proof" ) . entered ( ) ;
134- let Segment {
135- instret_start,
136- num_insns,
137- trace_heights,
138- } = segment;
139- assert_eq ! ( state. as_ref( ) . unwrap( ) . instret( ) , instret_start) ;
140- let from_state = Option :: take ( & mut state) . unwrap ( ) ;
141- vm. transport_init_memory_to_device ( & from_state. memory ) ;
142- let PreflightExecutionOutput {
143- system_records,
144- record_arenas,
145- to_state,
146- } = vm. execute_preflight (
147- & mut vm_instance. interpreter ,
148- from_state,
149- Some ( num_insns) ,
150- & trace_heights,
151- ) ?;
152- state = Some ( to_state) ;
34+ ) -> ExecutionStats {
35+ let trace = collect_trace ( exe, vm_config, inputs) ;
36+ generate_execution_stats ( blocks, trace)
37+ }
15338
154- // Generate proving context for each segment
155- let ctx = vm. generate_proving_ctx ( system_records, record_arenas) ?;
// Collects per-row trace data from a full VM run into a `Trace`: it registers
// each row under (segment index, timestamp), accumulates column values per
// program counter, and records which AIR (and column names) belong to each PC.
// NOTE(review): this span is a scraped diff hunk; lines 47-74 of the new file
// are elided at the `@@` marker below -- confirm against the full file.
39+ fn collect_trace ( exe : Arc < VmExe < BabyBear > > , vm_config : SpecializedConfig , inputs : StdIn ) -> Trace {
40+ let mut trace = Trace :: default ( ) ;
// Segment counter, bumped once per callback invocation (one per segment).
41+ let mut seg_idx = 0 ;
15642
43+ do_with_trace ( exe, vm_config, inputs, |vm, _pk, ctx| {
15744 let global_airs = vm
15845 . config ( )
15946 . create_airs ( )
@@ -188,26 +75,28 @@ pub fn execution_stats(
18875 let row = row. iter ( ) . map ( |v| v. as_canonical_u32 ( ) ) . collect :: < Vec < _ > > ( ) ;
18976 let pc_value = row[ pc_index] ;
19077 let ts_value = row[ ts_index] ;
191- rows_by_time. insert ( ( seg_idx, ts_value) , row. clone ( ) ) ;
78+ trace . rows_by_time . insert ( ( seg_idx, ts_value) , row. clone ( ) ) ;
19279
// PC 0 marks a padding row; it carries no instruction data.
19380 if pc_value == 0 {
19481 // Padding row!
19582 continue ;
19683 }
19784
198- if let Entry :: Vacant ( e) = trace_values_by_pc. entry ( pc_value) {
85+ if let Entry :: Vacant ( e) = trace . trace_values_by_pc . entry ( pc_value) {
19986 // First time we see this PC, initialize the column -> values map
20087 e. insert ( vec ! [ Vec :: new( ) ; row. len( ) ] ) ;
201- column_names_by_air_id. insert ( * air_id, column_names. clone ( ) ) ;
202- air_id_by_pc. insert ( pc_value, * air_id) ;
88+ trace
89+ . column_names_by_air_id
90+ . insert ( * air_id, column_names. clone ( ) ) ;
91+ trace. air_id_by_pc . insert ( pc_value, * air_id) ;
20392 }
204- let values_by_col = trace_values_by_pc. get_mut ( & pc_value) . unwrap ( ) ;
93+ let values_by_col = trace . trace_values_by_pc . get_mut ( & pc_value) . unwrap ( ) ;
// Sanity check: every occurrence of a PC must come from the same AIR.
20594 assert_eq ! (
206- air_id_by_pc[ & pc_value] ,
95+ trace . air_id_by_pc[ & pc_value] ,
20796 * air_id,
20897 "Mismatched air IDs for PC {}: {} vs {}" ,
20998 pc_value,
210- global_airs[ & air_id_by_pc[ & pc_value] ] . name( ) ,
99+ global_airs[ & trace . air_id_by_pc[ & pc_value] ] . name( ) ,
211100 air. name( )
212101 ) ;
213102 assert_eq ! ( values_by_col. len( ) , row. len( ) ) ;
@@ -217,8 +106,15 @@ pub fn execution_stats(
217106 }
218107 }
219108 }
220- }
109+ seg_idx += 1 ;
110+ } ) ;
111+ trace
112+ }
221113
// Turns the collected `Trace` into an `ExecutionStats` export: per-block row
// groupings, per-PC column value ranges (the comment below says 1st/99th
// percentile), and intersected equivalence classes. Also writes the result to
// "pgo_range_constraints.json" as a side effect before returning it.
// NOTE(review): this span is a scraped diff hunk with several `@@` elisions;
// the elided interior lines must be confirmed against the full file.
114+ fn generate_execution_stats (
115+ blocks : & [ BasicBlock < Instr < BabyBear > > ] ,
116+ trace : Trace ,
117+ ) -> ExecutionStats {
222118 // Block ID -> instruction count mapping
223119 let instruction_counts = blocks
224120 . iter ( )
@@ -228,7 +124,7 @@ pub fn execution_stats(
228124 // Block ID -> Vec<Vec<Row>>
229125 let mut block_rows = BTreeMap :: new ( ) ;
230126 let mut i = 0 ;
// Rows ordered by (segment, timestamp) thanks to the BTreeMap key order.
231- let rows_by_time = rows_by_time. values ( ) . collect :: < Vec < _ > > ( ) ;
127+ let rows_by_time = trace . rows_by_time . values ( ) . collect :: < Vec < _ > > ( ) ;
232128 while i < rows_by_time. len ( ) {
233129 let row = & rows_by_time[ i] ;
// NOTE(review): assumes column 0 of every row is the PC -- confirm this
// matches the pc_index used when the rows were collected.
234130 let pc_value = row[ 0 ] as u64 ;
@@ -288,7 +184,8 @@ pub fn execution_stats(
288184 . collect :: < BTreeMap < _ , _ > > ( ) ;
289185
290186 // Map all column values to their range (1st and 99th percentile) for each pc
291- let column_ranges_by_pc: HashMap < u32 , Vec < ( u32 , u32 ) > > = trace_values_by_pc
187+ let column_ranges_by_pc: HashMap < u32 , Vec < ( u32 , u32 ) > > = trace
188+ . trace_values_by_pc
292189 . into_iter ( )
293190 . map ( |( pc, values_by_col) | {
294191 let column_ranges = values_by_col
@@ -306,8 +203,8 @@ pub fn execution_stats(
306203 . collect ( ) ;
307204
308205 let export = ExecutionStats {
309- air_id_by_pc : air_id_by_pc. into_iter ( ) . collect ( ) ,
310- column_names_by_air_id : column_names_by_air_id. into_iter ( ) . collect ( ) ,
206+ air_id_by_pc : trace . air_id_by_pc . into_iter ( ) . collect ( ) ,
207+ column_names_by_air_id : trace . column_names_by_air_id . into_iter ( ) . collect ( ) ,
311208 column_ranges_by_pc : column_ranges_by_pc. into_iter ( ) . collect ( ) ,
312209 equivalence_classes_by_block : intersected_equivalence_classes,
313210 } ;
@@ -316,5 +213,61 @@ pub fn execution_stats(
// Side effect: persist the stats as pretty-printed JSON for PGO tooling.
316213 let json = serde_json:: to_string_pretty ( & export) . unwrap ( ) ;
317214 std:: fs:: write ( "pgo_range_constraints.json" , json) . unwrap ( ) ;
318215
319- Ok ( export)
216+ export
217+ }
218+
/// Computes the common refinement of several partitions.
///
/// Each partition is a list of equivalence classes of `Id`s. Two ids share a
/// class in the result exactly when they share a class in *every* input
/// partition. An id that is missing from at least one partition is its own
/// singleton class and is omitted from the output.
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
where
    Id: Eq + Hash + Copy,
{
    if partitions.is_empty() {
        return Vec::new();
    }

    // 1) For each partition, build a map: Id -> class_index
    let mut maps: Vec<HashMap<Id, usize>> = Vec::with_capacity(partitions.len());
    for part in partitions {
        let mut m = HashMap::new();
        for (class_idx, class) in part.iter().enumerate() {
            for &id in class {
                m.insert(id, class_idx);
            }
        }
        maps.push(m);
    }

    // 2) For each Id, build its "signature" of class indices across all
    //    partitions and group by that signature.
    //
    // Only ids present in the first partition can appear in the output: an id
    // missing from any partition is a singleton and is omitted. Iterating over
    // `maps[0]` therefore skips doomed ids up front and avoids materializing a
    // separate `HashSet` universe of all ids (which the previous version did).
    let mut grouped: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();
    for &id in maps[0].keys() {
        // Collecting `Option`s short-circuits to `None` if the id is missing
        // from some partition; such an id is its own equivalence class and is
        // left out of the output partition.
        let signature: Option<Vec<usize>> =
            maps.iter().map(|m| m.get(&id).copied()).collect();
        if let Some(signature) = signature {
            grouped.entry(signature).or_default().push(id);
        }
    }

    // 3) Resulting equivalence classes are the grouped values
    grouped.into_values().collect()
}
0 commit comments