Skip to content

Commit 88e712d

Browse files
committed
use symbolic machine in trace collection
1 parent bac0ffb commit 88e712d

File tree

1 file changed

+55
-30
lines changed

1 file changed

+55
-30
lines changed

openvm/src/empirical_constraints.rs

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
use itertools::Itertools;
22
use openvm_circuit::arch::VmCircuitConfig;
33
use openvm_sdk::StdIn;
4+
use openvm_stark_backend::p3_field::FieldAlgebra;
45
use openvm_stark_backend::p3_maybe_rayon::prelude::IntoParallelIterator;
56
use openvm_stark_backend::p3_maybe_rayon::prelude::ParallelIterator;
67
use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32;
8+
use openvm_stark_sdk::p3_baby_bear::BabyBear;
9+
use powdr_autoprecompiles::bus_map::BusType;
710
use powdr_autoprecompiles::empirical_constraints::{
811
intersect_partitions, DebugInfo, EmpiricalConstraints,
912
};
13+
use powdr_autoprecompiles::expression::AlgebraicEvaluator;
14+
use powdr_autoprecompiles::expression::RowEvaluator;
1015
use powdr_autoprecompiles::DegreeBound;
1116
use std::collections::btree_map::Entry;
1217
use std::collections::HashMap;
1318
use std::collections::{BTreeMap, BTreeSet};
1419

20+
use crate::bus_map::default_openvm_bus_map;
1521
use crate::trace_generation::do_with_trace;
1622
use crate::{CompiledProgram, OriginalCompiledProgram};
1723

@@ -69,7 +75,7 @@ pub fn detect_empirical_constraints(
6975
// If this becomes a RAM issue, we can also pass individual segments to process_trace.
7076
// The advantage of the current approach is that the percentiles can be computed more accurately.
7177
tracing::info!(" Collecting trace...");
72-
let (trace, new_debug_info) = collect_trace(&program, input);
78+
let (trace, new_debug_info) = collect_trace(&program, input, degree_bound.identities);
7379
tracing::info!(" Detecting constraints...");
7480
constraint_detector.process_trace(trace, new_debug_info);
7581
}
@@ -78,12 +84,17 @@ pub fn detect_empirical_constraints(
7884
constraint_detector.finalize()
7985
}
8086

81-
fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) {
87+
fn collect_trace(
88+
program: &CompiledProgram,
89+
inputs: StdIn,
90+
degree_bound: usize,
91+
) -> (Trace, DebugInfo) {
8292
let mut trace = Trace::default();
8393
let mut debug_info = DebugInfo::default();
8494
let mut seg_idx = 0;
8595

8696
do_with_trace(program, inputs, |vm, _pk, ctx| {
97+
let airs = program.vm_config.sdk.airs(degree_bound).unwrap();
8798
let global_airs = vm
8899
.config()
89100
.create_airs()
@@ -93,48 +104,61 @@ fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo)
93104
.collect::<HashMap<_, _>>();
94105

95106
for (air_id, proving_context) in &ctx.per_air {
96-
let air = &global_airs[air_id];
97-
let Some(column_names) = air.columns() else {
98-
// Instruction chips always have column names.
99-
continue;
100-
};
101-
102107
if !proving_context.cached_mains.is_empty() {
103108
// Instruction chips always have a cached main.
104109
continue;
105110
}
106111
let main = proving_context.common_main.as_ref().unwrap();
107-
assert_eq!(main.width, column_names.len());
108-
109-
// Instruction chips have a PC and time stamp
110-
let find_col = |name: &str| -> Option<usize> {
111-
column_names.iter().position(|col_name| {
112-
col_name == name || col_name == &format!("inner__{}", name)
112+
let air_name = global_airs[air_id].name();
113+
let (machine, _) = &airs.air_name_to_machine.get(&air_name).unwrap();
114+
115+
// Find the execution bus interation
116+
// This assumes there is exactly one, which is the case for instruction chips
117+
let execution_bus_interaction = machine
118+
.bus_interactions
119+
.iter()
120+
.find(|interaction| {
121+
interaction.id
122+
== default_openvm_bus_map()
123+
.get_bus_id(&BusType::ExecutionBridge)
124+
.unwrap()
113125
})
114-
};
115-
let Some(pc_index) = find_col("from_state__pc") else {
116-
continue;
117-
};
118-
let ts_index = find_col("from_state__timestamp").unwrap();
126+
.unwrap();
119127

120128
for row in main.row_slices() {
121-
let row = row.iter().map(|v| v.as_canonical_u32()).collect::<Vec<_>>();
122-
let pc_value = row[pc_index];
123-
let ts_value = row[ts_index];
129+
// Create an evaluator over this row
130+
let evaluator = RowEvaluator::new(row);
124131

125-
if pc_value == 0 {
126-
// Padding row!
132+
// Evaluate the execution bus interaction
133+
let execution = evaluator.eval_bus_interaction(execution_bus_interaction);
134+
135+
// `is_valid` is the multiplicity
136+
let is_valid = execution.mult;
137+
if is_valid == BabyBear::ZERO {
138+
// If `is_valid` is zero, this is a padding row
127139
continue;
128140
}
129141

142+
// Recover the values of the pc and timestamp
143+
let [pc, timestamp] = execution
144+
.args
145+
.map(|v| v.as_canonical_u32())
146+
.collect_vec()
147+
.try_into()
148+
.unwrap();
149+
150+
// Convert the row to u32s
151+
// TODO: is this necessary?
152+
let row = row.iter().map(|v| v.as_canonical_u32()).collect();
153+
130154
let row = Row {
131155
cells: row,
132-
pc: pc_value,
133-
timestamp: (seg_idx, ts_value),
156+
pc,
157+
timestamp: (seg_idx, timestamp),
134158
};
135159
trace.rows.push(row);
136160

137-
match debug_info.air_id_by_pc.entry(pc_value) {
161+
match debug_info.air_id_by_pc.entry(pc) {
138162
Entry::Vacant(entry) => {
139163
entry.insert(*air_id);
140164
}
@@ -143,9 +167,10 @@ fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo)
143167
}
144168
}
145169
if !debug_info.column_names_by_air_id.contains_key(air_id) {
146-
debug_info
147-
.column_names_by_air_id
148-
.insert(*air_id, column_names.clone());
170+
debug_info.column_names_by_air_id.insert(
171+
*air_id,
172+
machine.main_columns().map(|r| (*r.name).clone()).collect(),
173+
);
149174
}
150175
}
151176
}

0 commit comments

Comments
 (0)