diff --git a/autoprecompiles/src/empirical_constraints.rs b/autoprecompiles/src/empirical_constraints.rs index 8a5f8b018c..3de07e6612 100644 --- a/autoprecompiles/src/empirical_constraints.rs +++ b/autoprecompiles/src/empirical_constraints.rs @@ -1,25 +1,28 @@ use std::collections::btree_map::Entry; -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; use std::hash::Hash; use itertools::Itertools; use serde::{Deserialize, Serialize}; +pub use crate::equivalence_classes::{EquivalenceClass, EquivalenceClasses}; + /// "Constraints" that were inferred from execution statistics. They hold empirically /// (most of the time), but are not guaranteed to hold in all cases. -#[derive(Serialize, Deserialize, Clone, Default, Debug)] +#[derive(Serialize, Default, Debug)] pub struct EmpiricalConstraints { /// For each program counter, the range constraints for each column. /// The range might not hold in 100% of cases. pub column_ranges_by_pc: BTreeMap>, /// For each basic block (identified by its starting PC), the equivalence classes of columns. /// Each equivalence class is a list of (instruction index in block, column index). - pub equivalence_classes_by_block: BTreeMap>>, + pub equivalence_classes_by_block: BTreeMap>, + pub debug_info: DebugInfo, } /// Debug information mapping AIR ids to program counters and column names. -#[derive(Serialize, Deserialize, Default)] +#[derive(Serialize, Deserialize, Default, Debug)] pub struct DebugInfo { /// Mapping from program counter to AIR id. pub air_id_by_pc: BTreeMap, @@ -27,12 +30,6 @@ pub struct DebugInfo { pub column_names_by_air_id: BTreeMap>, } -#[derive(Serialize, Deserialize)] -pub struct EmpiricalConstraintsJson { - pub empirical_constraints: EmpiricalConstraints, - pub debug_info: DebugInfo, -} - impl EmpiricalConstraints { pub fn combine_with(&mut self, other: EmpiricalConstraints) { // Combine column ranges by PC @@ -52,15 +49,14 @@ impl EmpiricalConstraints { // Combine equivalence classes by block for (block_pc, classes) in other.equivalence_classes_by_block { - self.equivalence_classes_by_block + let existing = self + .equivalence_classes_by_block .entry(block_pc) - .and_modify(|existing_classes| { - let combined = - intersect_partitions(&[existing_classes.clone(), classes.clone()]); - *existing_classes = combined; - }) - .or_insert(classes); + .or_default(); + + *existing = intersect_partitions(vec![std::mem::take(existing), classes]); } + self.debug_info.combine_with(other.debug_info); } } @@ -88,16 +84,32 @@ fn merge_maps(map1: &mut BTreeMap, map2: BTreeMap Self { + Self { + row_idx, + column_idx, + } + } +} + /// Intersects multiple partitions of the same universe into a single partition. /// In other words, two elements are in the same equivalence class in the resulting partition /// if and only if they are in the same equivalence class in all input partitions. /// Singleton equivalence classes are omitted from the result. -pub fn intersect_partitions(partitions: &[BTreeSet>]) -> BTreeSet> -where - Id: Eq + Hash + Copy + Ord, -{ +pub fn intersect_partitions( + partitions: Vec>, +) -> EquivalenceClasses { // For each partition, build a map: Id -> class_index - let class_ids: Vec> = partitions + let class_ids: Vec> = partitions .iter() .map(|partition| { partition @@ -112,14 +124,14 @@ where partitions .iter() .flat_map(|partition| partition.iter()) - .flat_map(|class| class.iter().copied()) + .flat_map(|class| class.iter()) .unique() .filter_map(|id| { // Build the signature of the element: the list of class indices it belongs to // (one index per partition) class_ids .iter() - .map(|m| m.get(&id).cloned()) + .map(|m| m.get(id).cloned()) // If an element did not appear in any one of the partitions, it is // a singleton and we skip it. .collect::>>() @@ -128,16 +140,16 @@ where // Group elements by their signatures .into_group_map() .into_values() - // Remove singletons and convert to Set - .filter_map(|ids| (ids.len() > 1).then_some(ids.into_iter().collect())) + // Convert to set + .map(|ids| ids.into_iter().copied().collect()) .collect() } #[cfg(test)] mod tests { - use std::collections::BTreeSet; + use crate::empirical_constraints::EquivalenceClasses; - fn partition(sets: Vec>) -> BTreeSet> { + fn partition(sets: Vec>) -> EquivalenceClasses { sets.into_iter().map(|s| s.into_iter().collect()).collect() } @@ -156,7 +168,7 @@ mod tests { vec![6, 7, 8], ]); - let result = super::intersect_partitions(&[partition1, partition2]); + let result = super::intersect_partitions(vec![partition1, partition2]); let expected = partition(vec![vec![2, 3], vec![6, 7, 8]]); diff --git a/autoprecompiles/src/equivalence_classes.rs b/autoprecompiles/src/equivalence_classes.rs new file mode 100644 index 0000000000..70f98a41b6 --- /dev/null +++ b/autoprecompiles/src/equivalence_classes.rs @@ -0,0 +1,36 @@ +use std::collections::BTreeSet; + +use serde::Serialize; + +/// An equivalence class +pub type EquivalenceClass = BTreeSet; + +/// A collection of equivalence classes where all classes are guaranteed to have at least two elements +#[derive(Serialize, Debug, PartialEq, Eq)] +pub struct EquivalenceClasses { + inner: BTreeSet>, +} + +// TODO: derive +impl Default for EquivalenceClasses { + fn default() -> Self { + Self { + inner: Default::default(), + } + } +} + +impl FromIterator> for EquivalenceClasses { + fn from_iter>>(iter: I) -> Self { + // When collecting, we ignore classes with 0 or 1 elements as they are useless + Self { + inner: iter.into_iter().filter(|class| class.len() > 1).collect(), + } + } +} + +impl EquivalenceClasses { + pub fn iter(&self) -> impl Iterator> { + self.inner.iter() + } +} diff --git a/autoprecompiles/src/lib.rs b/autoprecompiles/src/lib.rs index 86a18f766d..3ad623b0e4 100644 --- a/autoprecompiles/src/lib.rs +++ b/autoprecompiles/src/lib.rs @@ -41,6 +41,7 @@ mod stats_logger; pub mod symbolic_machine_generator; pub use pgo::{PgoConfig, PgoType}; pub use powdr_constraint_solver::inliner::DegreeBound; +pub mod equivalence_classes; pub mod trace_handler; #[derive(Clone)] diff --git a/cli-openvm/src/main.rs b/cli-openvm/src/main.rs index 053095941b..48bb9d0569 100644 --- a/cli-openvm/src/main.rs +++ b/cli-openvm/src/main.rs @@ -3,7 +3,6 @@ use metrics_tracing_context::{MetricsLayer, TracingContextLayer}; use metrics_util::{debugging::DebuggingRecorder, layers::Layer}; use openvm_sdk::StdIn; use openvm_stark_sdk::bench::serialize_metric_snapshot; -use powdr_autoprecompiles::empirical_constraints::EmpiricalConstraintsJson; use powdr_autoprecompiles::pgo::{pgo_config, PgoType}; use powdr_autoprecompiles::PowdrConfig; use powdr_openvm::{compile_openvm, default_powdr_openvm_config, CompiledProgram, GuestOptions}; @@ -311,7 +310,7 @@ fn maybe_compute_empirical_constraints( "Optimistic precompiles are not implemented yet. Computing empirical constraints..." ); - let (empirical_constraints, debug_info) = + let empirical_constraints = detect_empirical_constraints(guest_program, powdr_config.degree_bound, vec![stdin]); if let Some(path) = &powdr_config.apc_candidates_dir_path { @@ -319,11 +318,7 @@ fn maybe_compute_empirical_constraints( "Saving empirical constraints debug info to {}/empirical_constraints.json", path.display() ); - let export = EmpiricalConstraintsJson { - empirical_constraints: empirical_constraints.clone(), - debug_info, - }; - let json = serde_json::to_string_pretty(&export).unwrap(); + let json = serde_json::to_string_pretty(&empirical_constraints).unwrap(); std::fs::write(path.join("empirical_constraints.json"), json).unwrap(); } } diff --git a/openvm/src/empirical_constraints.rs b/openvm/src/empirical_constraints.rs index ae2ded80e2..ddb01f6cc5 100644 --- a/openvm/src/empirical_constraints.rs +++ b/openvm/src/empirical_constraints.rs @@ -1,24 +1,38 @@ use itertools::Itertools; -use openvm_circuit::arch::VmCircuitConfig; use openvm_sdk::StdIn; use openvm_stark_backend::p3_maybe_rayon::prelude::IntoParallelIterator; use openvm_stark_backend::p3_maybe_rayon::prelude::ParallelIterator; use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use powdr_autoprecompiles::blocks::BasicBlock; +use powdr_autoprecompiles::bus_map::BusType; +use powdr_autoprecompiles::empirical_constraints::BlockCell; +use powdr_autoprecompiles::empirical_constraints::EquivalenceClasses; use powdr_autoprecompiles::empirical_constraints::{ intersect_partitions, DebugInfo, EmpiricalConstraints, }; +use powdr_autoprecompiles::expression::AlgebraicEvaluator; +use powdr_autoprecompiles::expression::RowEvaluator; use powdr_autoprecompiles::DegreeBound; use std::collections::btree_map::Entry; +use std::collections::BTreeMap; use std::collections::HashMap; -use std::collections::{BTreeMap, BTreeSet}; +use std::iter::once; -use crate::trace_generation::do_with_trace; +use crate::bus_map::default_openvm_bus_map; +use crate::trace_generation::do_with_cpu_trace; use crate::{CompiledProgram, OriginalCompiledProgram}; +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct Timestamp { + segment_id: usize, + value: u32, +} + #[derive(Debug)] struct Row { pc: u32, - timestamp: (u32, u32), + timestamp: Timestamp, cells: Vec, } @@ -37,11 +51,8 @@ impl Trace { }) } - fn rows_by_time(&self) -> Vec<&Row> { - self.rows - .iter() - .sorted_by_key(|row| row.timestamp) - .collect() + fn rows_sorted_by_time(&self) -> impl Iterator { + self.rows.iter().sorted_by_key(|row| &row.timestamp) } } @@ -49,18 +60,14 @@ pub fn detect_empirical_constraints( program: &OriginalCompiledProgram, degree_bound: DegreeBound, inputs: Vec, -) -> (EmpiricalConstraints, DebugInfo) { +) -> EmpiricalConstraints { tracing::info!("Collecting empirical constraints..."); let blocks = program.collect_basic_blocks(degree_bound.identities); - let instruction_counts = blocks - .iter() - .map(|block| (block.start_pc, block.statements.len())) - .collect(); // Collect trace, without any autoprecompiles. let program = program.compiled_program(Vec::new(), degree_bound.identities); - let mut constraint_detector = ConstraintDetector::new(instruction_counts); + let mut constraint_detector = ConstraintDetector::new(&blocks); let num_inputs = inputs.len(); for (i, input) in inputs.into_iter().enumerate() { @@ -69,7 +76,7 @@ pub fn detect_empirical_constraints( // If this becomes a RAM issue, we can also pass individual segments to process_trace. // The advantage of the current approach is that the percentiles can be computed more accurately. tracing::info!(" Collecting trace..."); - let (trace, new_debug_info) = collect_trace(&program, input); + let (trace, new_debug_info) = collect_trace(&program, input, degree_bound.identities); tracing::info!(" Detecting constraints..."); constraint_detector.process_trace(trace, new_debug_info); } @@ -78,63 +85,77 @@ pub fn detect_empirical_constraints( constraint_detector.finalize() } -fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) { +use openvm_stark_backend::p3_field::FieldAlgebra; + +fn collect_trace( + program: &CompiledProgram, + inputs: StdIn, + degree_bound: usize, +) -> (Trace, DebugInfo) { let mut trace = Trace::default(); let mut debug_info = DebugInfo::default(); - let mut seg_idx = 0; - - do_with_trace(program, inputs, |vm, _pk, ctx| { - let global_airs = vm - .config() - .create_airs() - .unwrap() - .into_airs() - .enumerate() - .collect::>(); - for (air_id, proving_context) in &ctx.per_air { - let air = &global_airs[air_id]; - let Some(column_names) = air.columns() else { - // Instruction chips always have column names. - continue; - }; + do_with_cpu_trace(program, inputs, |segment_id, _vm, _pk, ctx| { + let airs = program.vm_config.sdk.airs(degree_bound).unwrap(); + for (air_id, proving_context) in &ctx.per_air { if !proving_context.cached_mains.is_empty() { // Instruction chips always have a cached main. continue; } let main = proving_context.common_main.as_ref().unwrap(); - assert_eq!(main.width, column_names.len()); - - // Instruction chips have a PC and time stamp - let find_col = |name: &str| -> Option { - column_names.iter().position(|col_name| { - col_name == name || col_name == &format!("inner__{}", name) + let machine = &airs.machine_by_air_id.get(air_id).unwrap().machine; + + // Find the execution bus interation + // This assumes there is exactly one, which is the case for instruction chips + let execution_bus_interaction = machine + .bus_interactions + .iter() + .find(|interaction| { + interaction.id + == default_openvm_bus_map() + .get_bus_id(&BusType::ExecutionBridge) + .unwrap() }) - }; - let Some(pc_index) = find_col("from_state__pc") else { - continue; - }; - let ts_index = find_col("from_state__timestamp").unwrap(); + .unwrap(); for row in main.row_slices() { - let row = row.iter().map(|v| v.as_canonical_u32()).collect::>(); - let pc_value = row[pc_index]; - let ts_value = row[ts_index]; + // Create an evaluator over this row + let evaluator = RowEvaluator::new(row); - if pc_value == 0 { - // Padding row! + // Evaluate the execution bus interaction + let execution = evaluator.eval_bus_interaction(execution_bus_interaction); + + // `is_valid` is the multiplicity + let is_valid = execution.mult; + if is_valid == BabyBear::ZERO { + // If `is_valid` is zero, this is a padding row continue; } + // Recover the values of the pc and timestamp + let [pc, timestamp] = execution + .args + .map(|v| v.as_canonical_u32()) + .collect_vec() + .try_into() + .unwrap(); + + // Convert the row to u32s + // TODO: is this necessary? + let row = row.iter().map(|v| v.as_canonical_u32()).collect(); + let row = Row { cells: row, - pc: pc_value, - timestamp: (seg_idx, ts_value), + pc, + timestamp: Timestamp { + segment_id, + value: timestamp, + }, }; trace.rows.push(row); - match debug_info.air_id_by_pc.entry(pc_value) { + match debug_info.air_id_by_pc.entry(pc) { Entry::Vacant(entry) => { entry.insert(*air_id); } @@ -143,14 +164,13 @@ fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) } } if !debug_info.column_names_by_air_id.contains_key(air_id) { - debug_info - .column_names_by_air_id - .insert(*air_id, column_names.clone()); + debug_info.column_names_by_air_id.insert( + *air_id, + machine.main_columns().map(|r| (*r.name).clone()).collect(), + ); } } } - - seg_idx += 1; }) .unwrap(); (trace, debug_info) @@ -158,25 +178,51 @@ fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) struct ConstraintDetector { /// Mapping from block PC to number of instructions in that block - instruction_counts: HashMap, + instruction_count_by_start_pc: HashMap, empirical_constraints: EmpiricalConstraints, - debug_info: DebugInfo, +} + +struct ConcreteBlock<'a> { + rows: Vec<&'a Row>, +} + +impl<'a> ConcreteBlock<'a> { + fn equivalence_classes(&self) -> EquivalenceClasses { + self.rows + .iter() + .enumerate() + // Map each cell to a (value, (instruction_index, col_index)) pair + .flat_map(|(instruction_index, row)| { + row.cells + .iter() + .enumerate() + .map(|(col_index, v)| (*v, BlockCell::new(instruction_index, col_index))) + .collect::>() + }) + // Group by value + .into_group_map() + .into_values() + .map(|cells| cells.into_iter().collect()) + .collect() + } } impl ConstraintDetector { - pub fn new(instruction_counts: HashMap) -> Self { + pub fn new(blocks: &[BasicBlock]) -> Self { Self { - instruction_counts, + instruction_count_by_start_pc: blocks + .iter() + .map(|block| (block.start_pc, block.statements.len())) + .collect(), empirical_constraints: EmpiricalConstraints::default(), - debug_info: DebugInfo::default(), } } - pub fn finalize(self) -> (EmpiricalConstraints, DebugInfo) { - (self.empirical_constraints, self.debug_info) + pub fn finalize(self) -> EmpiricalConstraints { + self.empirical_constraints } - pub fn process_trace(&mut self, trace: Trace, new_debug_info: DebugInfo) { + pub fn process_trace(&mut self, trace: Trace, debug_info: DebugInfo) { // Compute empirical constraints from the current trace tracing::info!(" Detecting equivalence classes by block..."); let equivalence_classes_by_block = self.generate_equivalence_classes_by_block(&trace); @@ -185,12 +231,12 @@ impl ConstraintDetector { let new_empirical_constraints = EmpiricalConstraints { column_ranges_by_pc, equivalence_classes_by_block, + debug_info, }; // Combine the new empirical constraints and debug info with the existing ones self.empirical_constraints .combine_with(new_empirical_constraints); - self.debug_info.combine_with(new_debug_info); } fn detect_column_ranges_by_pc(&self, trace: Trace) -> BTreeMap> { @@ -226,7 +272,7 @@ impl ConstraintDetector { fn generate_equivalence_classes_by_block( &self, trace: &Trace, - ) -> BTreeMap>> { + ) -> BTreeMap> { tracing::info!(" Segmenting trace into blocks..."); let blocks = self.get_blocks(trace); tracing::info!(" Finding equivalence classes..."); @@ -235,18 +281,12 @@ impl ConstraintDetector { .map(|(block_id, block_instances)| { // Segment each block instance into equivalence classes let classes = block_instances - .into_iter() - .map(|block| self.block_equivalence_classes(block)) - .collect::>(); + .iter() + .map(ConcreteBlock::equivalence_classes) + .collect(); // Intersect the equivalence classes across all instances of the block - let intersected = intersect_partitions(&classes); - - // Remove singleton classes - let intersected = intersected - .into_iter() - .filter(|class| class.len() > 1) - .collect::>(); + let intersected = intersect_partitions(classes); (block_id, intersected) }) @@ -254,60 +294,42 @@ impl ConstraintDetector { } /// Segments a trace into basic blocks. - /// Returns a mapping from block ID to all instances of that block in the trace. - fn get_blocks<'a>(&self, trace: &'a Trace) -> BTreeMap>> { - let mut block_rows = BTreeMap::new(); - let mut row_index = 0; - let rows_by_time = trace.rows_by_time(); + /// Returns a mapping from block start pc to all instances of that block in the trace. + fn get_blocks<'a>(&self, trace: &'a Trace) -> BTreeMap>> { + trace + .rows_sorted_by_time() + // take entire blocks from the rows + .batching(|it| { + let first = it.next()?; + let block_id = first.pc as u64; - while row_index < rows_by_time.len() { - let first_row = rows_by_time[row_index]; - let block_id = first_row.pc as u64; + if let Some(&count) = self.instruction_count_by_start_pc.get(&block_id) { + let rows = once(first).chain(it.take(count - 1)).collect_vec(); - if let Some(instruction_count) = self.instruction_counts.get(&block_id) { - let block_row_slice = &rows_by_time[row_index..row_index + instruction_count]; + for (r1, r2) in rows.iter().tuple_windows() { + assert_eq!(r2.pc, r1.pc + 4); + } - for (row1, row2) in block_row_slice.iter().tuple_windows() { - assert_eq!(row2.pc, row1.pc + 4); + Some(Some((block_id, ConcreteBlock { rows }))) + } else { + // Single instruction block, yield `None` to be filtered. + Some(None) } - + }) + // filter out single instruction blocks + .flatten() + // collect by start_pc + .fold(Default::default(), |mut block_rows, (block_id, chunk)| { + block_rows.entry(block_id).or_insert(Vec::new()).push(chunk); block_rows - .entry(block_id) - .or_insert(Vec::new()) - .push(block_row_slice.to_vec()); - row_index += instruction_count; - } else { - // Single instruction block, ignore. - row_index += 1; - } - } - - block_rows - } - - fn block_equivalence_classes(&self, block: Vec<&Row>) -> BTreeSet> { - block - .into_iter() - .enumerate() - // Map each cell to a (value, (instruction_index, col_index)) pair - .flat_map(|(instruction_index, row)| { - row.cells - .iter() - .enumerate() - .map(|(col_index, v)| (*v, (instruction_index, col_index))) - .collect::>() }) - // Group by value - .into_group_map() - .values() - // Convert to set - .map(|v| v.clone().into_iter().collect()) - .collect() } } #[cfg(test)] mod tests { + use powdr_autoprecompiles::empirical_constraints::EquivalenceClass; + use super::*; fn make_trace(rows_by_time_with_pc: Vec<(u32, Vec)>) -> Trace { @@ -318,34 +340,26 @@ mod tests { .map(|(clk, (pc, cells))| Row { cells, pc, - timestamp: (0, clk as u32), + timestamp: Timestamp { + segment_id: 0, + value: clk as u32, + }, }) .collect(), } } - fn assert_equivalence_classes_equal( - actual: BTreeSet>, - expected: Vec>, - ) { - assert_eq!(actual.len(), expected.len()); - let mut actual = actual.into_iter(); - for expected_class in expected { - let actual_class = actual.next().unwrap(); - let expected_class_set: BTreeSet<(usize, usize)> = expected_class.into_iter().collect(); - assert_eq!(actual_class, expected_class_set); - } - assert!(actual.next().is_none()); - } - #[test] fn test_constraint_detector() { // Assume the following test program: // ADDI x1, x1, 1 // note how the second operand is always 1 // BLT x1, x2, -4 // Note how the first operand is always equal to the result of the previous ADDI - let instruction_counts = vec![(0, 2)].into_iter().collect(); - let mut detector = ConstraintDetector::new(instruction_counts); + let instruction_counts = vec![BasicBlock { + start_pc: 0, + statements: vec![(), ()], + }]; + let mut detector = ConstraintDetector::new(&instruction_counts); let trace1 = make_trace(vec![ (0, vec![1, 0, 1]), // ADDI: 0 + 1 = 1 @@ -355,7 +369,7 @@ mod tests { ]); detector.process_trace(trace1, DebugInfo::default()); - let (empirical_constraints, _debug_info) = detector.finalize(); + let empirical_constraints = detector.finalize(); assert_eq!( empirical_constraints.column_ranges_by_pc.get(&0), @@ -371,16 +385,16 @@ mod tests { let equivalence_classes = empirical_constraints .equivalence_classes_by_block .get(&0) - .unwrap() - .clone(); + .unwrap(); println!("Equivalence classes: {:?}", equivalence_classes); - assert_equivalence_classes_equal( - equivalence_classes, - vec![ - // The result of the first instruction (col 0) is always equal to the - // first operand of the second instruction (col 1) - vec![(0, 0), (1, 1)], - ], - ); + let expected: EquivalenceClasses<_> = once( + // The result of the first instruction (col 0) is always equal to the + // first operand of the second instruction (col 1) + [BlockCell::new(0, 0), BlockCell::new(1, 1)] + .into_iter() + .collect::>(), + ) + .collect(); + assert_eq!(*equivalence_classes, expected,); } } diff --git a/openvm/src/extraction_utils.rs b/openvm/src/extraction_utils.rs index 07a4d3ba6c..baa56e3dbc 100644 --- a/openvm/src/extraction_utils.rs +++ b/openvm/src/extraction_utils.rs @@ -50,32 +50,35 @@ use crate::utils::symbolic_to_algebraic; // TODO: Use ` as FieldExtensionAlgebra>>::D` instead after fixing p3 dependency const EXT_DEGREE: usize = 4; +#[derive(Clone, Serialize, Deserialize)] +pub struct OriginalMachine { + pub machine: SymbolicMachine, + pub name: String, + pub metrics: AirMetrics, +} + #[derive(Clone, Serialize, Deserialize, Default)] pub struct OriginalAirs { - pub(crate) opcode_to_air: HashMap, - pub(crate) air_name_to_machine: BTreeMap, AirMetrics)>, + pub(crate) air_id_by_opcode: HashMap, + pub(crate) machine_by_air_id: BTreeMap>, } impl InstructionHandler for OriginalAirs { type Field = F; type Instruction = Instr; - type AirId = String; + type AirId = usize; fn get_instruction_air_and_id( &self, instruction: &Self::Instruction, ) -> (Self::AirId, &SymbolicMachine) { - let id = self - .opcode_to_air - .get(&instruction.0.opcode) - .unwrap() - .clone(); - let air = &self.air_name_to_machine.get(&id).unwrap().0; + let id = *self.air_id_by_opcode.get(&instruction.0.opcode).unwrap(); + let air = &self.machine_by_air_id.get(&id).unwrap().machine; (id, air) } fn is_allowed(&self, instruction: &Self::Instruction) -> bool { - self.opcode_to_air.contains_key(&instruction.0.opcode) + self.air_id_by_opcode.contains_key(&instruction.0.opcode) } fn is_branching(&self, instruction: &Self::Instruction) -> bool { @@ -95,59 +98,63 @@ impl OriginalAirs { pub fn insert_opcode( &mut self, opcode: VmOpcode, - air_name: String, - machine: impl Fn() -> Result<(SymbolicMachine, AirMetrics), UnsupportedOpenVmReferenceError>, + air_id: usize, + machine: impl Fn() -> Result, UnsupportedOpenVmReferenceError>, ) -> Result<(), UnsupportedOpenVmReferenceError> { - if self.opcode_to_air.contains_key(&opcode) { + if self.air_id_by_opcode.contains_key(&opcode) { panic!("Opcode {opcode} already exists"); } // Insert the machine only if `air_name` isn't already present - if !self.air_name_to_machine.contains_key(&air_name) { + if let std::collections::btree_map::Entry::Vacant(e) = self.machine_by_air_id.entry(air_id) + { let machine_instance = machine()?; - self.air_name_to_machine - .insert(air_name.clone(), machine_instance); + e.insert(machine_instance); } - self.opcode_to_air.insert(opcode, air_name); + self.air_id_by_opcode.insert(opcode, air_id); + Ok(()) } pub fn get_instruction_metrics(&self, opcode: VmOpcode) -> Option<&AirMetrics> { - self.opcode_to_air.get(&opcode).and_then(|air_name| { - self.air_name_to_machine - .get(air_name) - .map(|(_, metrics)| metrics) + self.air_id_by_opcode.get(&opcode).and_then(|air_id| { + self.machine_by_air_id + .get(air_id) + .map(|machine| &machine.metrics) }) } pub fn allow_list(&self) -> Vec { - self.opcode_to_air.keys().cloned().collect() + self.air_id_by_opcode.keys().cloned().collect() } pub fn airs_by_name(&self) -> impl Iterator)> { - self.air_name_to_machine - .iter() - .map(|(name, (machine, _))| (name, machine)) + self.machine_by_air_id + .values() + .map(|machine| (&machine.name, &machine.machine)) } } /// For each air name, the dimension of a record arena needed to store the /// records for a single APC call. -pub fn record_arena_dimension_by_air_name_per_apc_call( +pub fn record_arena_dimension_by_air_id_per_apc_call( apc: &Apc>, air_by_opcode_id: &OriginalAirs, -) -> BTreeMap { +) -> BTreeMap { apc.instructions().iter().map(|instr| &instr.0.opcode).fold( BTreeMap::new(), |mut acc, opcode| { // Get the air name for this opcode - let air_name = air_by_opcode_id.opcode_to_air.get(opcode).unwrap(); + let air_id = air_by_opcode_id.air_id_by_opcode.get(opcode).unwrap(); // Increment the height for this air name, initializing if necessary - acc.entry(air_name.clone()) + acc.entry(*air_id) .or_insert_with(|| { - let (_, air_metrics) = - air_by_opcode_id.air_name_to_machine.get(air_name).unwrap(); + let air_metrics = &air_by_opcode_id + .machine_by_air_id + .get(air_id) + .unwrap() + .metrics; RecordArenaDimension { height: 0, @@ -300,11 +307,12 @@ impl OriginalVmConfig { }) .map(|(op, executor_id)| { let insertion_index = chip_inventory.executor_idx_to_insertion_idx[executor_id]; - let air_ref = &chip_inventory.airs().ext_airs()[insertion_index]; - (op, air_ref) + (op, insertion_index) }) // find executor for opcode - .try_fold(OriginalAirs::default(), |mut airs, (op, air_ref)| { - airs.insert_opcode(op, air_ref.name(), || { + .try_fold(OriginalAirs::default(), |mut airs, (op, insertion_idx)| { + airs.insert_opcode(op, insertion_idx, || { + let air_ref = &chip_inventory.airs().ext_airs()[insertion_idx]; + let name = air_ref.name(); let columns = get_columns(air_ref.clone()); let constraints = get_constraints(air_ref.clone()); let metrics = get_air_metrics(air_ref.clone(), max_degree); @@ -321,14 +329,17 @@ impl OriginalVmConfig { .map(|expr| openvm_bus_interaction_to_powdr(expr, &columns)) .collect::>()?; - Ok(( - SymbolicMachine { - constraints: powdr_exprs.into_iter().map(Into::into).collect(), - bus_interactions: powdr_bus_interactions, - derived_columns: vec![], - }, + let machine = SymbolicMachine { + constraints: powdr_exprs.into_iter().map(Into::into).collect(), + bus_interactions: powdr_bus_interactions, + derived_columns: vec![], + }; + + Ok(OriginalMachine { + machine, + name, metrics, - )) + }) })?; Ok(airs) diff --git a/openvm/src/lib.rs b/openvm/src/lib.rs index 6a43ce2452..5a8a8407e3 100644 --- a/openvm/src/lib.rs +++ b/openvm/src/lib.rs @@ -56,7 +56,7 @@ use crate::customize_exe::OpenVmApcCandidate; use crate::powdr_extension::chip::PowdrAir; pub use crate::program::Prog; pub use crate::program::{CompiledProgram, OriginalCompiledProgram}; -use crate::trace_generation::do_with_trace; +use crate::trace_generation::do_with_cpu_trace; #[cfg(test)] use crate::extraction_utils::AirWidthsDiff; @@ -776,7 +776,7 @@ pub fn prove( segment_height: Option, // uses the default height if None ) -> Result<(), Box> { if mock { - do_with_trace(program, inputs, |vm, pk, ctx| { + do_with_cpu_trace(program, inputs, |_, vm, pk, ctx| { debug_proving_ctx(vm, pk, &ctx); })?; } else { diff --git a/openvm/src/powdr_extension/executor/mod.rs b/openvm/src/powdr_extension/executor/mod.rs index ef9613e407..7b9354c4b6 100644 --- a/openvm/src/powdr_extension/executor/mod.rs +++ b/openvm/src/powdr_extension/executor/mod.rs @@ -8,7 +8,7 @@ use std::{ use crate::{ extraction_utils::{ - record_arena_dimension_by_air_name_per_apc_call, OriginalAirs, OriginalVmConfig, + record_arena_dimension_by_air_id_per_apc_call, OriginalAirs, OriginalVmConfig, }, Instr, }; @@ -94,10 +94,10 @@ impl OriginalArenas { /// Returns the arena of the given air name. /// - Panics if the arenas are not initialized. - pub fn take_arena(&mut self, air_name: &str) -> Option { + pub fn take_arena(&mut self, air_id: usize) -> Option { match self { OriginalArenas::Uninitialized => panic!("original arenas are uninitialized"), - OriginalArenas::Initialized(initialized) => initialized.take_arena(air_name), + OriginalArenas::Initialized(initialized) => initialized.take_arena(air_id), } } @@ -125,7 +125,7 @@ impl OriginalArenas { #[derive(Default)] pub struct InitializedOriginalArenas { arenas: Vec>, - air_name_to_arena_index: HashMap, + air_id_to_arena_index: HashMap, pub number_of_calls: usize, } @@ -137,22 +137,22 @@ impl InitializedOriginalArenas { apc: &Arc>>, ) -> Self { let record_arena_dimensions = - record_arena_dimension_by_air_name_per_apc_call(apc, original_airs); - let (air_name_to_arena_index, arenas) = + record_arena_dimension_by_air_id_per_apc_call(apc, original_airs); + let (air_id_to_arena_index, arenas) = record_arena_dimensions.into_iter().enumerate().fold( (HashMap::new(), Vec::new()), |(mut air_name_to_arena_index, mut arenas), ( idx, ( - air_name, + air_id, RecordArenaDimension { height: num_calls, width: air_width, }, ), )| { - air_name_to_arena_index.insert(air_name, idx); + air_name_to_arena_index.insert(air_id, idx); arenas.push(Some(A::with_capacity( num_calls * apc_call_count_estimate, air_width, @@ -163,7 +163,7 @@ impl InitializedOriginalArenas { Self { arenas, - air_name_to_arena_index, + air_id_to_arena_index, // This is the actual number of calls, which we don't know yet. It will be updated during preflight execution. number_of_calls: 0, } @@ -177,8 +177,8 @@ impl InitializedOriginalArenas { .expect("arena missing for index") } - fn take_arena(&mut self, air_name: &str) -> Option { - let index = *self.air_name_to_arena_index.get(air_name)?; + fn take_arena(&mut self, air_id: usize) -> Option { + let index = *self.air_id_to_arena_index.get(&air_id)?; self.arenas[index].take() } } @@ -574,10 +574,10 @@ impl PowdrExecutor { let executor_inventory = base_config.sdk_config.sdk.create_executors().unwrap(); let arena_index_by_name = - record_arena_dimension_by_air_name_per_apc_call(apc.as_ref(), &air_by_opcode_id) + record_arena_dimension_by_air_id_per_apc_call(apc.as_ref(), &air_by_opcode_id) .iter() .enumerate() - .map(|(idx, (name, _))| (name.clone(), idx)) + .map(|(idx, (name, _))| (*name, idx)) .collect::>(); let cached_instructions_meta = apc @@ -590,7 +590,7 @@ impl PowdrExecutor { .expect("missing executor for opcode") as usize; let air_name = air_by_opcode_id - .opcode_to_air + .air_id_by_opcode .get(&instruction.0.opcode) .expect("missing air for opcode"); let arena_index = *arena_index_by_name diff --git a/openvm/src/powdr_extension/trace_generator/cpu/mod.rs b/openvm/src/powdr_extension/trace_generator/cpu/mod.rs index 94d9ef6d3e..546f5b380d 100644 --- a/openvm/src/powdr_extension/trace_generator/cpu/mod.rs +++ b/openvm/src/powdr_extension/trace_generator/cpu/mod.rs @@ -119,16 +119,14 @@ impl PowdrTraceGeneratorCpu { .inventory }; - let dummy_trace_by_air_name: HashMap> = chip_inventory + let dummy_trace_by_air_id: HashMap> = chip_inventory .chips() .iter() .enumerate() .rev() - .filter_map(|(insertion_idx, chip)| { - let air_name = chip_inventory.airs().ext_airs()[insertion_idx].name(); - + .filter_map(|(air_id, chip)| { let record_arena = { - match original_arenas.take_arena(&air_name) { + match original_arenas.take_arena(air_id) { Some(ra) => ra, None => return None, // skip this iteration, because we only have record arena for chips that are used } @@ -136,7 +134,7 @@ impl PowdrTraceGeneratorCpu { let shared_trace = chip.generate_proving_ctx(record_arena).common_main.unwrap(); - Some((air_name, SharedCpuTrace::from(shared_trace))) + Some((air_id, SharedCpuTrace::from(shared_trace))) }) .collect(); @@ -146,7 +144,7 @@ impl PowdrTraceGeneratorCpu { apc_poly_id_to_index, columns_to_compute, } = generate_trace( - &dummy_trace_by_air_name, + &dummy_trace_by_air_id, &self.original_airs, num_apc_calls, &self.apc, diff --git a/openvm/src/trace_generation.rs b/openvm/src/trace_generation.rs index 96454e9b34..ff2efec21b 100644 --- a/openvm/src/trace_generation.rs +++ b/openvm/src/trace_generation.rs @@ -18,11 +18,12 @@ use crate::SpecializedConfigCpuBuilder as SpecializedConfigBuilder; use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine; /// Given a program and input, generates the trace segment by segment and calls the provided -/// callback with the VM, proving key, and proving context (containing the trace) for each segment. -pub fn do_with_trace( +/// callback with the segment index, VM, proving key, and proving context (containing the trace) for each segment. +pub fn do_with_cpu_trace( program: &CompiledProgram, inputs: StdIn, mut callback: impl FnMut( + usize, &VirtualMachine, &MultiStarkProvingKey, ProvingContext<::PB>, @@ -85,7 +86,7 @@ pub fn do_with_trace( let ctx = vm.generate_proving_ctx(system_records, record_arenas)?; - callback(&vm, &pk, ctx); + callback(seg_idx, &vm, &pk, ctx); } Ok(()) }