Skip to content

Commit 7851df5

Browse files
committed
Refactor
1 parent cfddc9a commit 7851df5

File tree

4 files changed

+114
-156
lines changed

4 files changed

+114
-156
lines changed

openvm/src/customize_exe.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,7 @@ pub fn customize<'a, P: PgoAdapter<Adapter = BabyBearOpenVmApcAdapter<'a>>>(
228228
),
229229
inputs,
230230
&blocks,
231-
)
232-
.unwrap();
231+
);
233232

234233
let start = std::time::Instant::now();
235234
let apcs = pgo.filter_blocks_and_create_apcs_with_pgo(

openvm/src/execution_stats.rs

Lines changed: 98 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,8 @@
1-
#![cfg_attr(feature = "tco", allow(internal_features))]
2-
#![cfg_attr(feature = "tco", allow(incomplete_features))]
3-
#![cfg_attr(feature = "tco", feature(explicit_tail_calls))]
4-
#![cfg_attr(feature = "tco", feature(core_intrinsics))]
5-
6-
use eyre::Result;
71
use itertools::Itertools;
8-
use openvm_circuit::arch::execution_mode::Segment;
9-
use openvm_circuit::arch::{PreflightExecutionOutput, VirtualMachine, VmCircuitConfig, VmInstance};
2+
use openvm_circuit::arch::VmCircuitConfig;
103
use openvm_instructions::exe::VmExe;
11-
use openvm_sdk::prover::vm::new_local_prover;
12-
use openvm_sdk::{
13-
config::{AppConfig, DEFAULT_APP_LOG_BLOWUP},
14-
StdIn,
15-
};
4+
use openvm_sdk::StdIn;
165
use openvm_stark_backend::p3_matrix::dense::DenseMatrix;
17-
use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPermutationEngine;
18-
use openvm_stark_sdk::config::FriParameters;
196
use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32;
207
use openvm_stark_sdk::p3_baby_bear::BabyBear;
218
use powdr_autoprecompiles::blocks::BasicBlock;
@@ -24,136 +11,36 @@ use std::collections::hash_map::Entry;
2411
use std::collections::BTreeMap;
2512
use std::{collections::HashMap, sync::Arc};
2613

27-
#[cfg(not(feature = "cuda"))]
28-
use crate::PowdrSdkCpu;
29-
use crate::{Instr, SpecializedConfig, SpecializedConfigCpuBuilder};
30-
use tracing::info_span;
14+
use crate::trace_generation::do_with_trace;
15+
use crate::{Instr, SpecializedConfig};
3116

3217
use std::collections::HashSet;
3318
use std::hash::Hash;
3419

35-
// ChatGPT generated code
36-
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
37-
where
38-
Id: Eq + Hash + Copy,
39-
{
40-
if partitions.is_empty() {
41-
return Vec::new();
42-
}
43-
44-
// 1) For each partition, build a map: Id -> class_index
45-
let mut maps: Vec<HashMap<Id, usize>> = Vec::with_capacity(partitions.len());
46-
for part in partitions {
47-
let mut m = HashMap::new();
48-
for (class_idx, class) in part.iter().enumerate() {
49-
for &id in class {
50-
m.insert(id, class_idx);
51-
}
52-
}
53-
maps.push(m);
54-
}
55-
56-
// 2) Collect the universe of all Ids
57-
let mut universe: HashSet<Id> = HashSet::new();
58-
for part in partitions {
59-
for class in part {
60-
for &id in class {
61-
universe.insert(id);
62-
}
63-
}
64-
}
65-
66-
// 3) For each Id, build its "signature" of class indices across all partitions
67-
// and group by that signature.
68-
let mut grouped: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();
69-
70-
for &id in &universe {
71-
let mut signature = Vec::with_capacity(maps.len());
72-
let mut is_singleton = false;
73-
for m in &maps {
74-
let Some(class_idx) = m.get(&id) else {
75-
// The element did not appear in one of the partition, so it is its
76-
// own equivalence class. We can also omit it in the output partition.
77-
is_singleton = true;
78-
break;
79-
};
80-
signature.push(*class_idx);
81-
}
82-
if !is_singleton {
83-
grouped.entry(signature).or_default().push(id);
84-
}
85-
}
86-
87-
// 4) Resulting equivalence classes are the grouped values
88-
grouped.into_values().collect()
20+
#[derive(Default)]
21+
struct Trace {
22+
/// Mapping (segment_idx, timestamp) -> Vec<u32>
23+
rows_by_time: BTreeMap<(usize, u32), Vec<u32>>,
24+
trace_values_by_pc: HashMap<u32, Vec<Vec<u32>>>,
25+
air_id_by_pc: HashMap<u32, usize>,
26+
column_names_by_air_id: HashMap<usize, Vec<String>>,
8927
}
9028

9129
pub fn execution_stats(
9230
exe: Arc<VmExe<BabyBear>>,
9331
vm_config: SpecializedConfig,
9432
inputs: StdIn,
9533
blocks: &[BasicBlock<Instr<BabyBear>>],
96-
) -> Result<ExecutionStats, Box<dyn std::error::Error>> {
97-
// Set app configuration
98-
let app_fri_params =
99-
FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_APP_LOG_BLOWUP);
100-
let app_config = AppConfig::new(app_fri_params, vm_config.clone());
101-
102-
// Create the SDK
103-
#[cfg(feature = "cuda")]
104-
let sdk = PowdrSdkGpu::new(app_config).unwrap();
105-
#[cfg(not(feature = "cuda"))]
106-
let sdk = PowdrSdkCpu::new(app_config).unwrap();
107-
// Build owned vm instance, so we can mutate it later
108-
let vm_builder = sdk.app_vm_builder().clone();
109-
let vm_pk = sdk.app_pk().app_vm_pk.clone();
110-
let exe = sdk.convert_to_exe(exe.clone())?;
111-
let mut vm_instance: VmInstance<_, _> = new_local_prover(vm_builder, &vm_pk, exe.clone())?;
112-
113-
vm_instance.reset_state(inputs.clone());
114-
let metered_ctx = vm_instance.vm.build_metered_ctx(&exe);
115-
let metered_interpreter = vm_instance.vm.metered_interpreter(vm_instance.exe())?;
116-
let (segments, _) = metered_interpreter.execute_metered(inputs.clone(), metered_ctx)?;
117-
let mut state = vm_instance.state_mut().take();
118-
119-
// Get reusable inputs for `debug_proving_ctx`, the mock prover API from OVM.
120-
let vm: &mut VirtualMachine<BabyBearPermutationEngine<_>, SpecializedConfigCpuBuilder> =
121-
&mut vm_instance.vm;
122-
123-
// Mapping (segment_idx, timestamp) -> Vec<u32>
124-
let mut rows_by_time = BTreeMap::new();
125-
126-
let mut trace_values_by_pc = HashMap::new();
127-
let mut column_names_by_air_id = HashMap::new();
128-
let mut air_id_by_pc = HashMap::new();
129-
130-
for (seg_idx, segment) in segments.into_iter().enumerate() {
131-
let _segment_span = info_span!("prove_segment", segment = seg_idx).entered();
132-
// We need a separate span so the metric label includes "segment" from _segment_span
133-
let _prove_span = info_span!("total_proof").entered();
134-
let Segment {
135-
instret_start,
136-
num_insns,
137-
trace_heights,
138-
} = segment;
139-
assert_eq!(state.as_ref().unwrap().instret(), instret_start);
140-
let from_state = Option::take(&mut state).unwrap();
141-
vm.transport_init_memory_to_device(&from_state.memory);
142-
let PreflightExecutionOutput {
143-
system_records,
144-
record_arenas,
145-
to_state,
146-
} = vm.execute_preflight(
147-
&mut vm_instance.interpreter,
148-
from_state,
149-
Some(num_insns),
150-
&trace_heights,
151-
)?;
152-
state = Some(to_state);
34+
) -> ExecutionStats {
35+
let trace = collect_trace(exe, vm_config, inputs);
36+
generate_execution_stats(blocks, trace)
37+
}
15338

154-
// Generate proving context for each segment
155-
let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
39+
fn collect_trace(exe: Arc<VmExe<BabyBear>>, vm_config: SpecializedConfig, inputs: StdIn) -> Trace {
40+
let mut trace = Trace::default();
41+
let mut seg_idx = 0;
15642

43+
do_with_trace(exe, vm_config, inputs, |vm, _pk, ctx| {
15744
let global_airs = vm
15845
.config()
15946
.create_airs()
@@ -188,26 +75,28 @@ pub fn execution_stats(
18875
let row = row.iter().map(|v| v.as_canonical_u32()).collect::<Vec<_>>();
18976
let pc_value = row[pc_index];
19077
let ts_value = row[ts_index];
191-
rows_by_time.insert((seg_idx, ts_value), row.clone());
78+
trace.rows_by_time.insert((seg_idx, ts_value), row.clone());
19279

19380
if pc_value == 0 {
19481
// Padding row!
19582
continue;
19683
}
19784

198-
if let Entry::Vacant(e) = trace_values_by_pc.entry(pc_value) {
85+
if let Entry::Vacant(e) = trace.trace_values_by_pc.entry(pc_value) {
19986
// First time we see this PC, initialize the column -> values map
20087
e.insert(vec![Vec::new(); row.len()]);
201-
column_names_by_air_id.insert(*air_id, column_names.clone());
202-
air_id_by_pc.insert(pc_value, *air_id);
88+
trace
89+
.column_names_by_air_id
90+
.insert(*air_id, column_names.clone());
91+
trace.air_id_by_pc.insert(pc_value, *air_id);
20392
}
204-
let values_by_col = trace_values_by_pc.get_mut(&pc_value).unwrap();
93+
let values_by_col = trace.trace_values_by_pc.get_mut(&pc_value).unwrap();
20594
assert_eq!(
206-
air_id_by_pc[&pc_value],
95+
trace.air_id_by_pc[&pc_value],
20796
*air_id,
20897
"Mismatched air IDs for PC {}: {} vs {}",
20998
pc_value,
210-
global_airs[&air_id_by_pc[&pc_value]].name(),
99+
global_airs[&trace.air_id_by_pc[&pc_value]].name(),
211100
air.name()
212101
);
213102
assert_eq!(values_by_col.len(), row.len());
@@ -217,8 +106,15 @@ pub fn execution_stats(
217106
}
218107
}
219108
}
220-
}
109+
seg_idx += 1;
110+
});
111+
trace
112+
}
221113

114+
fn generate_execution_stats(
115+
blocks: &[BasicBlock<Instr<BabyBear>>],
116+
trace: Trace,
117+
) -> ExecutionStats {
222118
// Block ID -> instruction count mapping
223119
let instruction_counts = blocks
224120
.iter()
@@ -228,7 +124,7 @@ pub fn execution_stats(
228124
// Block ID -> Vec<Vec<Row>>
229125
let mut block_rows = BTreeMap::new();
230126
let mut i = 0;
231-
let rows_by_time = rows_by_time.values().collect::<Vec<_>>();
127+
let rows_by_time = trace.rows_by_time.values().collect::<Vec<_>>();
232128
while i < rows_by_time.len() {
233129
let row = &rows_by_time[i];
234130
let pc_value = row[0] as u64;
@@ -288,7 +184,8 @@ pub fn execution_stats(
288184
.collect::<BTreeMap<_, _>>();
289185

290186
// Map all column values to their range (1st and 99th percentile) for each pc
291-
let column_ranges_by_pc: HashMap<u32, Vec<(u32, u32)>> = trace_values_by_pc
187+
let column_ranges_by_pc: HashMap<u32, Vec<(u32, u32)>> = trace
188+
.trace_values_by_pc
292189
.into_iter()
293190
.map(|(pc, values_by_col)| {
294191
let column_ranges = values_by_col
@@ -306,8 +203,8 @@ pub fn execution_stats(
306203
.collect();
307204

308205
let export = ExecutionStats {
309-
air_id_by_pc: air_id_by_pc.into_iter().collect(),
310-
column_names_by_air_id: column_names_by_air_id.into_iter().collect(),
206+
air_id_by_pc: trace.air_id_by_pc.into_iter().collect(),
207+
column_names_by_air_id: trace.column_names_by_air_id.into_iter().collect(),
311208
column_ranges_by_pc: column_ranges_by_pc.into_iter().collect(),
312209
equivalence_classes_by_block: intersected_equivalence_classes,
313210
};
@@ -316,5 +213,61 @@ pub fn execution_stats(
316213
let json = serde_json::to_string_pretty(&export).unwrap();
317214
std::fs::write("pgo_range_constraints.json", json).unwrap();
318215

319-
Ok(export)
216+
export
217+
}
218+
219+
// ChatGPT generated code
220+
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
221+
where
222+
Id: Eq + Hash + Copy,
223+
{
224+
if partitions.is_empty() {
225+
return Vec::new();
226+
}
227+
228+
// 1) For each partition, build a map: Id -> class_index
229+
let mut maps: Vec<HashMap<Id, usize>> = Vec::with_capacity(partitions.len());
230+
for part in partitions {
231+
let mut m = HashMap::new();
232+
for (class_idx, class) in part.iter().enumerate() {
233+
for &id in class {
234+
m.insert(id, class_idx);
235+
}
236+
}
237+
maps.push(m);
238+
}
239+
240+
// 2) Collect the universe of all Ids
241+
let mut universe: HashSet<Id> = HashSet::new();
242+
for part in partitions {
243+
for class in part {
244+
for &id in class {
245+
universe.insert(id);
246+
}
247+
}
248+
}
249+
250+
// 3) For each Id, build its "signature" of class indices across all partitions
251+
// and group by that signature.
252+
let mut grouped: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();
253+
254+
for &id in &universe {
255+
let mut signature = Vec::with_capacity(maps.len());
256+
let mut is_singleton = false;
257+
for m in &maps {
258+
let Some(class_idx) = m.get(&id) else {
259+
// The element did not appear in one of the partition, so it is its
260+
// own equivalence class. We can also omit it in the output partition.
261+
is_singleton = true;
262+
break;
263+
};
264+
signature.push(*class_idx);
265+
}
266+
if !is_singleton {
267+
grouped.entry(signature).or_default().push(id);
268+
}
269+
}
270+
271+
// 4) Resulting equivalence classes are the grouped values
272+
grouped.into_values().collect()
320273
}

openvm/src/lib.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,14 @@ pub fn prove(
880880
#[cfg(not(feature = "cuda"))]
881881
let sdk = PowdrSdkCpu::new(app_config).unwrap();
882882
if mock {
883-
do_with_trace(program, inputs, |vm, pk, ctx| {
884-
debug_proving_ctx(vm, pk, &ctx);
885-
});
883+
do_with_trace(
884+
program.exe.clone(),
885+
program.vm_config.clone(),
886+
inputs,
887+
|vm, pk, ctx| {
888+
debug_proving_ctx(vm, pk, &ctx);
889+
},
890+
);
886891
} else {
887892
let mut app_prover = sdk.app_prover(exe.clone())?;
888893

openvm/src/trace_generation.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
use std::sync::Arc;
2+
13
use openvm_circuit::arch::{
24
execution_mode::Segment, PreflightExecutionOutput, VirtualMachine, VmCircuitConfig, VmInstance,
35
};
6+
use openvm_instructions::exe::VmExe;
47
use openvm_sdk::{
58
config::{AppConfig, DEFAULT_APP_LOG_BLOWUP},
69
prover::vm::new_local_prover,
710
StdIn,
811
};
912
use openvm_stark_backend::{keygen::types::MultiStarkProvingKey, prover::types::ProvingContext};
10-
use openvm_stark_sdk::{config::FriParameters, engine::StarkEngine};
13+
use openvm_stark_sdk::{config::FriParameters, engine::StarkEngine, p3_baby_bear::BabyBear};
1114
use tracing::info_span;
1215

13-
use crate::{BabyBearSC, CompiledProgram, SpecializedConfigCpuBuilder};
16+
use crate::{BabyBearSC, SpecializedConfig, SpecializedConfigCpuBuilder};
1417

1518
#[cfg(not(feature = "cuda"))]
1619
use crate::PowdrSdkCpu as PowdrSdk;
@@ -25,17 +28,15 @@ use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine;
2528
/// Given a program and input, generates the trace segment by segment and calls the provided
2629
/// callback with the VM, proving key, and proving context (containing the trace) for each segment.
2730
pub fn do_with_trace(
28-
program: &CompiledProgram,
31+
exe: Arc<VmExe<BabyBear>>,
32+
vm_config: SpecializedConfig,
2933
inputs: StdIn,
3034
mut callback: impl FnMut(
3135
&VirtualMachine<BabyBearPoseidon2Engine, SpecializedConfigCpuBuilder>,
3236
&MultiStarkProvingKey<BabyBearSC>,
3337
ProvingContext<<BabyBearPoseidon2Engine as StarkEngine>::PB>,
3438
),
3539
) {
36-
let exe = &program.exe;
37-
let vm_config = program.vm_config.clone();
38-
3940
// Set app configuration
4041
let app_fri_params =
4142
FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_APP_LOG_BLOWUP);

0 commit comments

Comments
 (0)