Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion autoprecompiles/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ where
>;
type CustomBusTypes: Clone + Display + Sync + Eq + PartialEq;
type ApcStats: Send + Sync;
type AirId: Eq + Hash + Send + Sync;

fn into_field(e: Self::PowdrField) -> Self::Field;

Expand Down
1 change: 0 additions & 1 deletion openvm/src/customize_exe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ impl<'a> Adapter for BabyBearOpenVmApcAdapter<'a> {
OpenVmMemoryBusInteraction<Self::PowdrField, V>;
type CustomBusTypes = OpenVmBusType;
type ApcStats = OvmApcStats;
type AirId = String;

fn into_field(e: Self::PowdrField) -> Self::Field {
openvm_stark_sdk::p3_baby_bear::BabyBear::from_canonical_u32(
Expand Down
94 changes: 55 additions & 39 deletions openvm/src/empirical_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use itertools::Itertools;
use openvm_circuit::arch::VmCircuitConfig;
use openvm_sdk::StdIn;
use openvm_stark_backend::p3_field::FieldAlgebra;
use openvm_stark_backend::p3_maybe_rayon::prelude::IntoParallelIterator;
use openvm_stark_backend::p3_maybe_rayon::prelude::ParallelIterator;
use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32;
use openvm_stark_sdk::p3_baby_bear::BabyBear;
use powdr_autoprecompiles::bus_map::BusType;
use powdr_autoprecompiles::empirical_constraints::{
intersect_partitions, DebugInfo, EmpiricalConstraints,
};
use powdr_autoprecompiles::expression::AlgebraicEvaluator;
use powdr_autoprecompiles::expression::RowEvaluator;
use powdr_autoprecompiles::DegreeBound;
use std::collections::btree_map::Entry;
use std::collections::HashMap;
use std::collections::{BTreeMap, BTreeSet};

use crate::bus_map::default_openvm_bus_map;
use crate::trace_generation::do_with_trace;
use crate::{CompiledProgram, OriginalCompiledProgram};

Expand Down Expand Up @@ -69,7 +74,7 @@ pub fn detect_empirical_constraints(
// If this becomes a RAM issue, we can also pass individual segments to process_trace.
// The advantage of the current approach is that the percentiles can be computed more accurately.
tracing::info!(" Collecting trace...");
let (trace, new_debug_info) = collect_trace(&program, input);
let (trace, new_debug_info) = collect_trace(&program, input, degree_bound.identities);
tracing::info!(" Detecting constraints...");
constraint_detector.process_trace(trace, new_debug_info);
}
Expand All @@ -78,63 +83,73 @@ pub fn detect_empirical_constraints(
constraint_detector.finalize()
}

fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) {
fn collect_trace(
program: &CompiledProgram,
inputs: StdIn,
degree_bound: usize,
) -> (Trace, DebugInfo) {
let mut trace = Trace::default();
let mut debug_info = DebugInfo::default();
let mut seg_idx = 0;

do_with_trace(program, inputs, |vm, _pk, ctx| {
let global_airs = vm
.config()
.create_airs()
.unwrap()
.into_airs()
.enumerate()
.collect::<HashMap<_, _>>();
do_with_trace(program, inputs, |_vm, _pk, ctx| {
let airs = program.vm_config.sdk.airs(degree_bound).unwrap();

for (air_id, proving_context) in &ctx.per_air {
let air = &global_airs[air_id];
let Some(column_names) = air.columns() else {
// Instruction chips always have column names.
continue;
};

if !proving_context.cached_mains.is_empty() {
// Instruction chips always have a cached main.
continue;
}
let main = proving_context.common_main.as_ref().unwrap();
assert_eq!(main.width, column_names.len());

// Instruction chips have a PC and time stamp
let find_col = |name: &str| -> Option<usize> {
column_names.iter().position(|col_name| {
col_name == name || col_name == &format!("inner__{}", name)
let machine = &airs.machine_by_air_id.get(air_id).unwrap().machine;

// Find the execution bus interation
// This assumes there is exactly one, which is the case for instruction chips
let execution_bus_interaction = machine
.bus_interactions
.iter()
.find(|interaction| {
interaction.id
== default_openvm_bus_map()
.get_bus_id(&BusType::ExecutionBridge)
.unwrap()
})
};
let Some(pc_index) = find_col("from_state__pc") else {
continue;
};
let ts_index = find_col("from_state__timestamp").unwrap();
.unwrap();

for row in main.row_slices() {
let row = row.iter().map(|v| v.as_canonical_u32()).collect::<Vec<_>>();
let pc_value = row[pc_index];
let ts_value = row[ts_index];
// Create an evaluator over this row
let evaluator = RowEvaluator::new(row);

// Evaluate the execution bus interaction
let execution = evaluator.eval_bus_interaction(execution_bus_interaction);

if pc_value == 0 {
// Padding row!
// `is_valid` is the multiplicity
let is_valid = execution.mult;
if is_valid == BabyBear::ZERO {
// If `is_valid` is zero, this is a padding row
continue;
}

// Recover the values of the pc and timestamp
let [pc, timestamp] = execution
.args
.map(|v| v.as_canonical_u32())
.collect_vec()
.try_into()
.unwrap();

// Convert the row to u32s
// TODO: is this necessary?
let row = row.iter().map(|v| v.as_canonical_u32()).collect();

let row = Row {
cells: row,
pc: pc_value,
timestamp: (seg_idx, ts_value),
pc,
timestamp: (seg_idx, timestamp),
};
trace.rows.push(row);

match debug_info.air_id_by_pc.entry(pc_value) {
match debug_info.air_id_by_pc.entry(pc) {
Entry::Vacant(entry) => {
entry.insert(*air_id);
}
Expand All @@ -143,9 +158,10 @@ fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo)
}
}
if !debug_info.column_names_by_air_id.contains_key(air_id) {
debug_info
.column_names_by_air_id
.insert(*air_id, column_names.clone());
debug_info.column_names_by_air_id.insert(
*air_id,
machine.main_columns().map(|r| (*r.name).clone()).collect(),
);
}
}
}
Expand Down
95 changes: 53 additions & 42 deletions openvm/src/extraction_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex};

Expand Down Expand Up @@ -50,32 +51,35 @@ use crate::utils::symbolic_to_algebraic;
// TODO: Use `<PackedChallenge<BabyBearSC> as FieldExtensionAlgebra<Val<BabyBearSC>>>::D` instead after fixing p3 dependency
const EXT_DEGREE: usize = 4;

#[derive(Clone, Serialize, Deserialize)]
pub struct OriginalMachine<F> {
pub machine: SymbolicMachine<F>,
pub name: String,
pub metrics: AirMetrics,
}

#[derive(Clone, Serialize, Deserialize, Default)]
pub struct OriginalAirs<F> {
pub(crate) opcode_to_air: HashMap<VmOpcode, String>,
pub(crate) air_name_to_machine: BTreeMap<String, (SymbolicMachine<F>, AirMetrics)>,
pub(crate) air_id_by_opcode: HashMap<VmOpcode, usize>,
pub(crate) machine_by_air_id: BTreeMap<usize, OriginalMachine<F>>,
}

impl<F> InstructionHandler for OriginalAirs<F> {
type Field = F;
type Instruction = Instr<F>;
type AirId = String;
type AirId = usize;

fn get_instruction_air_and_id(
&self,
instruction: &Self::Instruction,
) -> (Self::AirId, &SymbolicMachine<Self::Field>) {
let id = self
.opcode_to_air
.get(&instruction.0.opcode)
.unwrap()
.clone();
let air = &self.air_name_to_machine.get(&id).unwrap().0;
let id = *self.air_id_by_opcode.get(&instruction.0.opcode).unwrap();
let air = &self.machine_by_air_id.get(&id).unwrap().machine;
(id, air)
}

fn is_allowed(&self, instruction: &Self::Instruction) -> bool {
self.opcode_to_air.contains_key(&instruction.0.opcode)
self.air_id_by_opcode.contains_key(&instruction.0.opcode)
}

fn is_branching(&self, instruction: &Self::Instruction) -> bool {
Expand All @@ -95,59 +99,62 @@ impl<F> OriginalAirs<F> {
pub fn insert_opcode(
&mut self,
opcode: VmOpcode,
air_name: String,
machine: impl Fn() -> Result<(SymbolicMachine<F>, AirMetrics), UnsupportedOpenVmReferenceError>,
air_id: usize,
machine: impl Fn() -> Result<OriginalMachine<F>, UnsupportedOpenVmReferenceError>,
) -> Result<(), UnsupportedOpenVmReferenceError> {
if self.opcode_to_air.contains_key(&opcode) {
if self.air_id_by_opcode.contains_key(&opcode) {
panic!("Opcode {opcode} already exists");
}
// Insert the machine only if `air_name` isn't already present
if !self.air_name_to_machine.contains_key(&air_name) {
if let Entry::Vacant(e) = self.machine_by_air_id.entry(air_id) {
let machine_instance = machine()?;
self.air_name_to_machine
.insert(air_name.clone(), machine_instance);
e.insert(machine_instance);
}

self.opcode_to_air.insert(opcode, air_name);
self.air_id_by_opcode.insert(opcode, air_id);

Ok(())
}

pub fn get_instruction_metrics(&self, opcode: VmOpcode) -> Option<&AirMetrics> {
self.opcode_to_air.get(&opcode).and_then(|air_name| {
self.air_name_to_machine
.get(air_name)
.map(|(_, metrics)| metrics)
self.air_id_by_opcode.get(&opcode).and_then(|air_id| {
self.machine_by_air_id
.get(air_id)
.map(|machine| &machine.metrics)
})
}

pub fn allow_list(&self) -> Vec<VmOpcode> {
self.opcode_to_air.keys().cloned().collect()
self.air_id_by_opcode.keys().cloned().collect()
}

pub fn airs_by_name(&self) -> impl Iterator<Item = (&String, &SymbolicMachine<F>)> {
self.air_name_to_machine
.iter()
.map(|(name, (machine, _))| (name, machine))
self.machine_by_air_id
.values()
.map(|machine| (&machine.name, &machine.machine))
}
}

/// For each air name, the dimension of a record arena needed to store the
/// records for a single APC call.
pub fn record_arena_dimension_by_air_name_per_apc_call<F>(
pub fn record_arena_dimension_by_air_id_per_apc_call<F>(
apc: &Apc<F, Instr<F>>,
air_by_opcode_id: &OriginalAirs<F>,
) -> BTreeMap<String, RecordArenaDimension> {
) -> BTreeMap<usize, RecordArenaDimension> {
apc.instructions().iter().map(|instr| &instr.0.opcode).fold(
BTreeMap::new(),
|mut acc, opcode| {
// Get the air name for this opcode
let air_name = air_by_opcode_id.opcode_to_air.get(opcode).unwrap();
let air_id = air_by_opcode_id.air_id_by_opcode.get(opcode).unwrap();

// Increment the height for this air name, initializing if necessary
acc.entry(air_name.clone())
acc.entry(*air_id)
.or_insert_with(|| {
let (_, air_metrics) =
air_by_opcode_id.air_name_to_machine.get(air_name).unwrap();
let air_metrics = &air_by_opcode_id
.machine_by_air_id
.get(air_id)
.unwrap()
.metrics;

RecordArenaDimension {
height: 0,
Expand Down Expand Up @@ -300,11 +307,12 @@ impl OriginalVmConfig {
})
.map(|(op, executor_id)| {
let insertion_index = chip_inventory.executor_idx_to_insertion_idx[executor_id];
let air_ref = &chip_inventory.airs().ext_airs()[insertion_index];
(op, air_ref)
(op, insertion_index)
}) // find executor for opcode
.try_fold(OriginalAirs::default(), |mut airs, (op, air_ref)| {
airs.insert_opcode(op, air_ref.name(), || {
.try_fold(OriginalAirs::default(), |mut airs, (op, insertion_idx)| {
airs.insert_opcode(op, insertion_idx, || {
let air_ref = &chip_inventory.airs().ext_airs()[insertion_idx];
let name = air_ref.name();
let columns = get_columns(air_ref.clone());
let constraints = get_constraints(air_ref.clone());
let metrics = get_air_metrics(air_ref.clone(), max_degree);
Expand All @@ -321,14 +329,17 @@ impl OriginalVmConfig {
.map(|expr| openvm_bus_interaction_to_powdr(expr, &columns))
.collect::<Result<_, _>>()?;

Ok((
SymbolicMachine {
constraints: powdr_exprs.into_iter().map(Into::into).collect(),
bus_interactions: powdr_bus_interactions,
derived_columns: vec![],
},
let machine = SymbolicMachine {
constraints: powdr_exprs.into_iter().map(Into::into).collect(),
bus_interactions: powdr_bus_interactions,
derived_columns: vec![],
};

Ok(OriginalMachine {
machine,
name,
metrics,
))
})
})?;

Ok(airs)
Expand Down
Loading
Loading