Commit 06b6528

Refactor: Extract module
1 parent 2080c20 commit 06b6528

2 files changed (+345, -319 lines)

openvm/src/execution_stats.rs

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
#![cfg_attr(feature = "tco", allow(internal_features))]
#![cfg_attr(feature = "tco", allow(incomplete_features))]
#![cfg_attr(feature = "tco", feature(explicit_tail_calls))]
#![cfg_attr(feature = "tco", feature(core_intrinsics))]

use eyre::Result;
use itertools::Itertools;
use openvm_circuit::arch::execution_mode::metered::segment_ctx::SegmentationLimits;
use openvm_circuit::arch::execution_mode::Segment;
use openvm_circuit::arch::{PreflightExecutionOutput, VirtualMachine, VmCircuitConfig, VmInstance};
use openvm_sdk::prover::vm::new_local_prover;
use openvm_sdk::{
    config::{AppConfig, DEFAULT_APP_LOG_BLOWUP},
    StdIn,
};
use openvm_stark_backend::p3_matrix::dense::DenseMatrix;
use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPermutationEngine;
use openvm_stark_sdk::config::FriParameters;
use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32;
use openvm_stark_sdk::p3_baby_bear::BabyBear;
use powdr_autoprecompiles::JsonExport;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::{collections::HashMap, path::PathBuf, sync::Arc};

#[cfg(feature = "cuda")]
use crate::PowdrSdkGpu; // assumed GPU counterpart of PowdrSdkCpu (used below under the "cuda" feature)
#[cfg(not(feature = "cuda"))]
use crate::PowdrSdkCpu;
use crate::{CompiledProgram, SpecializedConfigCpuBuilder};
use tracing::info_span;

use std::collections::HashSet;
use std::hash::Hash;
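
// Computes the common refinement of the given partitions: two Ids end up in the
// same output class iff they share a class in *every* input partition. E.g.
// intersecting [[1, 2], [3]] with [[1], [2, 3]] yields the singletons [1], [2], [3].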
// ChatGPT generated code
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
where
    Id: Eq + Hash + Copy,
{
    if partitions.is_empty() {
        return Vec::new();
    }

    // 1) For each partition, build a map: Id -> class_index
    let mut maps: Vec<HashMap<Id, usize>> = Vec::with_capacity(partitions.len());
    for part in partitions {
        let mut m = HashMap::new();
        for (class_idx, class) in part.iter().enumerate() {
            for &id in class {
                m.insert(id, class_idx);
            }
        }
        maps.push(m);
    }

    // 2) Collect the universe of all Ids
    let mut universe: HashSet<Id> = HashSet::new();
    for part in partitions {
        for class in part {
            for &id in class {
                universe.insert(id);
            }
        }
    }

    // 3) For each Id, build its "signature" of class indices across all partitions
    //    and group by that signature.
    let mut grouped: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();

    for &id in &universe {
        let mut signature = Vec::with_capacity(maps.len());
        for m in &maps {
            let class_idx = m.get(&id).expect("id missing in some partition");
            signature.push(*class_idx);
        }
        grouped.entry(signature).or_default().push(id);
    }

    // 4) Resulting equivalence classes are the grouped values
    grouped.into_values().collect()
}
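
// A minimal sanity-test sketch for `intersect_partitions` (hypothetical, for
// illustration only); class order is unspecified, so the test sorts before comparing.
#[cfg(test)]
mod intersect_partitions_tests {
    use super::*;

    #[test]
    fn common_refinement_of_two_partitions() {
        // {1, 2}, {3} intersected with {1}, {2, 3} refines to singletons.
        let partitions = vec![vec![vec![1, 2], vec![3]], vec![vec![1], vec![2, 3]]];
        let mut classes = intersect_partitions(&partitions);
        for class in &mut classes {
            class.sort();
        }
        classes.sort();
        assert_eq!(classes, vec![vec![1], vec![2], vec![3]]);
    }
}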

pub fn execution_stats(
    program: &CompiledProgram,
    inputs: StdIn,
    segment_height: Option<usize>, // uses the default height if None
    apc_candidates_dir: Option<PathBuf>,
) -> Result<(), Box<dyn std::error::Error>> {
    let exe = &program.exe;
    let mut vm_config = program.vm_config.clone();

    // DefaultSegmentationStrategy { max_segment_len: 4194204, max_cells_per_chip_in_segment: 503304480 }
    if let Some(segment_height) = segment_height {
        vm_config
            .sdk
            .config_mut()
            .sdk
            .system
            .config
            .segmentation_limits =
            SegmentationLimits::default().with_max_trace_height(segment_height as u32);
        tracing::debug!("Setting max segment len to {}", segment_height);
    }

    // Set app configuration
    let app_fri_params =
        FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_APP_LOG_BLOWUP);
    let app_config = AppConfig::new(app_fri_params, vm_config.clone());

    // Create the SDK
    #[cfg(feature = "cuda")]
    let sdk = PowdrSdkGpu::new(app_config).unwrap();
    #[cfg(not(feature = "cuda"))]
    let sdk = PowdrSdkCpu::new(app_config).unwrap();
    // Build owned vm instance, so we can mutate it later
    let vm_builder = sdk.app_vm_builder().clone();
    let vm_pk = sdk.app_pk().app_vm_pk.clone();
    let exe = sdk.convert_to_exe(exe.clone())?;
    let mut vm_instance: VmInstance<_, _> = new_local_prover(vm_builder, &vm_pk, exe.clone())?;

    vm_instance.reset_state(inputs.clone());
    let metered_ctx = vm_instance.vm.build_metered_ctx(&exe);
    let metered_interpreter = vm_instance.vm.metered_interpreter(vm_instance.exe())?;
    let (segments, _) = metered_interpreter.execute_metered(inputs.clone(), metered_ctx)?;
    let mut state = vm_instance.state_mut().take();
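
    // Two-phase execution: the metered pass above fixed the segment boundaries
    // and per-chip trace heights; the preflight pass below re-executes each
    // segment to produce the records needed for trace generation.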

    // Get reusable inputs for `debug_proving_ctx`, the mock prover API from OVM.
    let vm: &mut VirtualMachine<BabyBearPermutationEngine<_>, SpecializedConfigCpuBuilder> =
        &mut vm_instance.vm;

    // Mapping (segment_idx, timestamp) -> Vec<u32>
    let mut rows_by_time = BTreeMap::new();

    let mut trace_values_by_pc = HashMap::new();
    let mut column_names_by_air_id = HashMap::new();
    let mut air_id_by_pc = HashMap::new();

    for (seg_idx, segment) in segments.into_iter().enumerate() {
        let _segment_span = info_span!("prove_segment", segment = seg_idx).entered();
        // We need a separate span so the metric label includes "segment" from _segment_span
        let _prove_span = info_span!("total_proof").entered();
        let Segment {
            instret_start,
            num_insns,
            trace_heights,
        } = segment;
        assert_eq!(state.as_ref().unwrap().instret(), instret_start);
        let from_state = state.take().unwrap();
        vm.transport_init_memory_to_device(&from_state.memory);
        let PreflightExecutionOutput {
            system_records,
            record_arenas,
            to_state,
        } = vm.execute_preflight(
            &mut vm_instance.interpreter,
            from_state,
            Some(num_insns),
            &trace_heights,
        )?;
        state = Some(to_state);

        // Generate proving context for each segment
        let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;

        let global_airs = vm
            .config()
            .create_airs()
            .unwrap()
            .into_airs()
            .enumerate()
            .collect::<HashMap<_, _>>();

        for (air_id, proving_context) in &ctx.per_air {
            if !proving_context.cached_mains.is_empty() {
                // Not the case for instruction circuits
                continue;
            }
            let main: &Arc<DenseMatrix<BabyBear>> = proving_context.common_main.as_ref().unwrap();

            let air = &global_airs[air_id];
            let Some(column_names) = air.columns() else {
                continue;
            };
            assert_eq!(main.width, column_names.len());

            // This is the case for all instruction circuits
            let Some(pc_index) = column_names
                .iter()
                .position(|name| name == "from_state__pc")
            else {
                continue;
            };
            // The timestamp is assumed to be at column index 1, right after the PC.
            let ts_index = 1;

            for row in main.row_slices() {
                let row = row.iter().map(|v| v.as_canonical_u32()).collect::<Vec<_>>();
                let pc_value = row[pc_index];
                let ts_value = row[ts_index];
                rows_by_time.insert((seg_idx, ts_value), row.clone());

                if pc_value == 0 {
                    // Padding row!
                    continue;
                }

                if let Entry::Vacant(e) = trace_values_by_pc.entry(pc_value) {
                    // First time we see this PC: initialize the column -> values map
                    e.insert(vec![Vec::new(); row.len()]);
                    column_names_by_air_id.insert(*air_id, column_names.clone());
                    air_id_by_pc.insert(pc_value, *air_id);
                }
                let values_by_col = trace_values_by_pc.get_mut(&pc_value).unwrap();
                assert_eq!(
                    air_id_by_pc[&pc_value],
                    *air_id,
                    "Mismatched air IDs for PC {}: {} vs {}",
                    pc_value,
                    global_airs[&air_id_by_pc[&pc_value]].name(),
                    air.name()
                );
                assert_eq!(values_by_col.len(), row.len());

                for (col_idx, value) in row.iter().enumerate() {
                    values_by_col[col_idx].push(*value);
                }
            }
        }
    }
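
    // At this point we have, for every non-padding instruction row:
    //   rows_by_time:        (segment, timestamp) -> full trace row, i.e. all
    //                        rows in execution order,
    //   trace_values_by_pc:  pc -> per-column list of all observed values,
    //   air_id_by_pc and column_names_by_air_id: metadata to interpret them.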

    let apc_candidates_dir = apc_candidates_dir.unwrap();
    let apc_candidates: powdr_autoprecompiles::pgo::JsonExport = {
        let json_str =
            std::fs::read_to_string(apc_candidates_dir.join("apc_candidates.json")).unwrap();
        serde_json::from_str(&json_str).unwrap()
    };
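    // Assumed minimal shape of apc_candidates.json, inferred from the field
    // accesses below (illustrative, not a schema):
    //   { "apcs": [ { "original_block": { "start_pc": 4096, "statements": [ ... ] } } ] }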
    let apcs = apc_candidates.apcs;

    // Block ID -> instruction count mapping
    let instruction_counts = apcs
        .iter()
        .map(|apc| {
            (
                apc.original_block.start_pc,
                apc.original_block.statements.len(),
            )
        })
        .collect::<HashMap<_, _>>();

    // Block ID -> Vec<Vec<Row>>
    let mut block_rows = BTreeMap::new();
    let mut i = 0;
    let rows_by_time = rows_by_time.values().collect::<Vec<_>>();
    while i < rows_by_time.len() {
        let row = &rows_by_time[i];
        // Assumes the PC is the first column of every collected row.
        let pc_value = row[0] as u64;

        if instruction_counts.contains_key(&pc_value) {
            let instruction_count = instruction_counts[&pc_value];
            let block_row_slice = &rows_by_time[i..i + instruction_count];
            block_rows
                .entry(pc_value)
                .or_default()
                .push(block_row_slice.to_vec());
            i += instruction_count;
        } else {
            i += 1;
        }
    }
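
    // Example: if an APC block starting at pc 0x1000 spans 3 instructions, then
    // each occurrence of pc 0x1000 in the execution-ordered rows contributes
    // that row plus the following two as one instance of the block.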

    // Block ID -> Vec<Vec<Vec<(instruction_index, col_index)>>>:
    // Indices: block ID, instance idx, equivalence class idx, cell
    let equivalence_classes = block_rows
        .into_iter()
        .map(|(block_id, blocks)| {
            let classes = blocks
                .into_iter()
                .map(|rows| {
                    let value_to_cells = rows
                        .into_iter()
                        .enumerate()
                        .flat_map(|(instruction_index, row)| {
                            row.iter()
                                .enumerate()
                                .map(|(col_index, v)| (*v, (instruction_index, col_index)))
                                .collect::<Vec<_>>()
                        })
                        .into_group_map();
                    value_to_cells.values().cloned().collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();
            (block_id, classes)
        })
        .collect::<HashMap<_, _>>();

    // Intersect equivalence classes across all instances
    let intersected_equivalence_classes = equivalence_classes
        .into_iter()
        .map(|(block_id, classes)| {
            let intersected = intersect_partitions(&classes);

            // Remove singleton classes
            let intersected = intersected
                .into_iter()
                .filter(|class| class.len() > 1)
                .collect::<Vec<_>>();

            (block_id, intersected)
        })
        .collect::<BTreeMap<_, _>>();
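
    // Each surviving class lists trace cells (instruction_index, col_index) that
    // held identical values in every observed instance of the block; these
    // empirically-equal cells are what gets exported below.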

    // Map all column values to their range (1st and 99th percentile) for each pc
    let column_ranges_by_pc: HashMap<u32, Vec<(u32, u32)>> = trace_values_by_pc
        .into_iter()
        .map(|(pc, values_by_col)| {
            let column_ranges = values_by_col
                .into_iter()
                .map(|mut values| {
                    values.sort_unstable();
                    let len = values.len();
                    let p1_index = len / 100; // 1st percentile
                    let p99_index = len * 99 / 100; // 99th percentile
                    (values[p1_index], values[p99_index])
                })
                .collect();
            (pc, column_ranges)
        })
        .collect();
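
    // E.g. with 1000 samples per column this keeps values[10]..=values[990],
    // trimming roughly 1% of outliers at each end of the observed range.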

    let export = JsonExport {
        air_id_by_pc: air_id_by_pc.into_iter().collect(),
        column_names_by_air_id: column_names_by_air_id.into_iter().collect(),
        column_ranges_by_pc: column_ranges_by_pc.into_iter().collect(),
        equivalence_classes_by_block: intersected_equivalence_classes,
    };

    // Write to pgo_range_constraints.json
    let json = serde_json::to_string_pretty(&export).unwrap();
    std::fs::write("pgo_range_constraints.json", json).unwrap();

    Ok(())
}
