Skip to content

Commit d40352e

Browse files
committed
Collect empirical constraints
1 parent bee4c27 commit d40352e

File tree

7 files changed

+355
-0
lines changed

7 files changed

+355
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use std::collections::BTreeMap;
2+
3+
use serde::{Deserialize, Serialize};
4+
5+
/// "Constraints" that were inferred from execution statistics.
6+
#[derive(Serialize, Deserialize, Clone, Default)]
7+
pub struct EmpiricalConstraints {
8+
/// For each program counter, the range constraints for each column.
9+
/// The range might not hold in 100% of cases.
10+
pub column_ranges_by_pc: BTreeMap<u32, Vec<(u32, u32)>>,
11+
/// For each basic block (identified by its starting PC), the equivalence classes of columns.
12+
/// Each equivalence class is a list of (instruction index in block, column index).
13+
pub equivalence_classes_by_block: BTreeMap<u64, Vec<Vec<(usize, usize)>>>,
14+
}
15+
16+
/// Debug information mapping AIR ids to program counters and column names.
17+
#[derive(Serialize, Deserialize)]
18+
pub struct DebugInfo {
19+
/// Mapping from program counter to AIR id.
20+
pub air_id_by_pc: BTreeMap<u32, usize>,
21+
/// Mapping from AIR id to column names.
22+
pub column_names_by_air_id: BTreeMap<usize, Vec<String>>,
23+
}
24+
25+
#[derive(Serialize, Deserialize)]
26+
pub struct EmpiricalConstraintsJson {
27+
pub empirical_constraints: EmpiricalConstraints,
28+
pub debug_info: DebugInfo,
29+
}

autoprecompiles/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod adapter;
2626
pub mod blocks;
2727
pub mod bus_map;
2828
pub mod constraint_optimizer;
29+
pub mod empirical_constraints;
2930
pub mod evaluation;
3031
pub mod execution_profile;
3132
pub mod expression;

cli-openvm/src/main.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use metrics_util::{debugging::DebuggingRecorder, layers::Layer};
44
use openvm_sdk::StdIn;
55
use openvm_stark_sdk::bench::serialize_metric_snapshot;
66
use powdr_autoprecompiles::pgo::{pgo_config, PgoType};
7+
use powdr_openvm::detect_empirical_constraints;
78
use powdr_openvm::{compile_openvm, default_powdr_openvm_config, CompiledProgram, GuestOptions};
89

910
#[cfg(feature = "metrics")]
@@ -144,6 +145,11 @@ fn run_command(command: Commands) {
144145
let execution_profile =
145146
powdr_openvm::execution_profile_from_guest(&guest_program, stdin_from(input));
146147

148+
let _empirical_constraints = detect_empirical_constraints(
149+
&guest_program,
150+
powdr_config.degree_bound,
151+
stdin_from(input),
152+
);
147153
let pgo_config = pgo_config(pgo, max_columns, execution_profile);
148154
let program =
149155
powdr_openvm::compile_exe(guest_program, powdr_config, pgo_config).unwrap();

openvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ derive_more = { version = "2.0.1", default-features = false, features = [
6161
"from",
6262
] }
6363
itertools = "0.14.0"
64+
serde_json = "1.0.140"
6465

6566
tracing = "0.1.40"
6667
tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] }
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
use itertools::Itertools;
2+
use openvm_circuit::arch::VmCircuitConfig;
3+
use openvm_sdk::StdIn;
4+
use openvm_stark_backend::p3_matrix::dense::DenseMatrix;
5+
use openvm_stark_sdk::openvm_stark_backend::p3_field::PrimeField32;
6+
use openvm_stark_sdk::p3_baby_bear::BabyBear;
7+
use powdr_autoprecompiles::blocks::BasicBlock;
8+
use powdr_autoprecompiles::empirical_constraints::{
9+
DebugInfo, EmpiricalConstraints, EmpiricalConstraintsJson,
10+
};
11+
use powdr_autoprecompiles::DegreeBound;
12+
use std::collections::hash_map::Entry;
13+
use std::collections::BTreeMap;
14+
use std::{collections::HashMap, sync::Arc};
15+
16+
use crate::trace_generation::do_with_trace;
17+
use crate::{CompiledProgram, Instr, OriginalCompiledProgram};
18+
19+
use std::collections::HashSet;
20+
use std::hash::Hash;
21+
22+
/// Materialized execution trace, Indexed by time and by PC
///
/// Padding rows (identified by `pc == 0` in `collect_trace`) are kept in
/// `rows` and `rows_by_time` but are excluded from `rows_by_pc`.
#[derive(Default)]
struct Trace {
    /// The raw rows, in any order
    rows: Vec<Vec<u32>>,
    /// Mapping (segment_idx, timestamp) -> row index in `rows`
    rows_by_time: BTreeMap<(usize, u32), usize>,
    /// PC value -> List of row indices in `rows` with that PC
    rows_by_pc: HashMap<u32, Vec<usize>>,
}
32+
33+
pub fn detect_empirical_constraints(
34+
program: &OriginalCompiledProgram,
35+
degree_bound: DegreeBound,
36+
inputs: StdIn,
37+
) -> EmpiricalConstraints {
38+
let blocks = program.collect_basic_blocks(degree_bound.identities);
39+
40+
// Collect trace, without any autoprecompiles.
41+
let program = program.compiled_program(Vec::new(), degree_bound.identities);
42+
let (trace, debug_info) = collect_trace(&program, inputs);
43+
let empirical_constraints = generate_empirical_constraints(&blocks, trace);
44+
45+
// Export to disk
46+
let export = EmpiricalConstraintsJson {
47+
empirical_constraints: empirical_constraints.clone(),
48+
debug_info,
49+
};
50+
let json = serde_json::to_string_pretty(&export).unwrap();
51+
std::fs::write("empirical_constraints.json", json).unwrap();
52+
53+
empirical_constraints
54+
}
55+
56+
/// Runs `program` on `inputs` and materializes the main trace of every
/// instruction circuit, together with debug information mapping program
/// counters to AIR ids and AIR ids to column names.
///
/// Non-instruction circuits are skipped: chips with cached mains, AIRs without
/// column names, and AIRs without a `from_state__pc` column.
fn collect_trace(program: &CompiledProgram, inputs: StdIn) -> (Trace, DebugInfo) {
    let mut trace = Trace::default();
    let mut debug_info = DebugInfo {
        air_id_by_pc: BTreeMap::new(),
        column_names_by_air_id: BTreeMap::new(),
    };
    // Incremented once per callback invocation; disambiguates timestamps
    // across segments in `rows_by_time`.
    let mut seg_idx = 0;

    do_with_trace(program, inputs, |vm, _pk, ctx| {
        // AIR id -> AIR, so we can look up column names and AIR names below.
        let global_airs = vm
            .config()
            .create_airs()
            .unwrap()
            .into_airs()
            .enumerate()
            .collect::<HashMap<_, _>>();

        for (air_id, proving_context) in &ctx.per_air {
            if !proving_context.cached_mains.is_empty() {
                // Not the case for instruction circuits
                continue;
            }
            let main: &Arc<DenseMatrix<BabyBear>> = proving_context.common_main.as_ref().unwrap();

            let air = &global_airs[air_id];
            let Some(column_names) = air.columns() else {
                continue;
            };
            // Each trace column must have a name; otherwise the debug info
            // would be inconsistent with the rows we record.
            assert_eq!(main.width, column_names.len());

            // This is the case for all instruction circuits
            let Some(pc_index) = column_names
                .iter()
                .position(|name| name == "from_state__pc")
            else {
                continue;
            };
            // NOTE(review): assumes the timestamp is always column 1 — unlike
            // the PC, it is not located by name; confirm against the AIR layout.
            let ts_index = 1;

            for row in main.row_slices() {
                let row = row.iter().map(|v| v.as_canonical_u32()).collect::<Vec<_>>();
                let pc_value = row[pc_index];
                let ts_value = row[ts_index];
                trace.rows.push(row);
                let row_index = trace.rows.len() - 1;
                trace.rows_by_time.insert((seg_idx, ts_value), row_index);

                if pc_value == 0 {
                    // Padding row!
                    // NOTE(review): assumes pc == 0 only ever occurs on padding
                    // rows — confirm that no real instruction executes at PC 0.
                    continue;
                }

                match trace.rows_by_pc.entry(pc_value) {
                    Entry::Vacant(e) => {
                        // First time we see this PC, initialize the column -> values map
                        e.insert(vec![row_index]);
                        debug_info
                            .column_names_by_air_id
                            .insert(*air_id, column_names.clone());
                        debug_info.air_id_by_pc.insert(pc_value, *air_id);
                    }
                    Entry::Occupied(mut o) => {
                        let rows = o.get_mut();
                        // A given PC must always be executed by the same chip.
                        assert_eq!(
                            debug_info.air_id_by_pc[&pc_value],
                            *air_id,
                            "Mismatched air IDs for PC {}: {} vs {}",
                            pc_value,
                            global_airs[&debug_info.air_id_by_pc[&pc_value]].name(),
                            air.name()
                        );
                        rows.push(row_index);
                    }
                }
            }
        }

        seg_idx += 1;
    })
    .unwrap();
    (trace, debug_info)
}
138+
139+
fn generate_empirical_constraints(
140+
blocks: &[BasicBlock<Instr<BabyBear>>],
141+
trace: Trace,
142+
) -> EmpiricalConstraints {
143+
// Block ID -> instruction count mapping
144+
let instruction_counts = blocks
145+
.iter()
146+
.map(|block| (block.start_pc, block.statements.len()))
147+
.collect::<HashMap<_, _>>();
148+
149+
// Block ID -> Vec<Vec<Row>>
150+
let mut block_rows = BTreeMap::new();
151+
let mut i = 0;
152+
let rows_by_time = trace.rows_by_time.values().collect::<Vec<_>>();
153+
while i < rows_by_time.len() {
154+
let row = &trace.rows[*rows_by_time[i]];
155+
let pc_value = row[0] as u64;
156+
157+
if instruction_counts.contains_key(&pc_value) {
158+
let instruction_count = instruction_counts[&pc_value];
159+
let block_row_slice = &rows_by_time[i..i + instruction_count];
160+
block_rows
161+
.entry(pc_value)
162+
.or_insert(Vec::new())
163+
.push(block_row_slice.to_vec());
164+
i += instruction_count;
165+
} else {
166+
i += 1;
167+
}
168+
}
169+
170+
// Block ID -> Vec<Vec<Vec<(instruction_index, col_index)>>>:
171+
// Indices: block ID, instance idx, equivalence class idx, cell
172+
let equivalence_classes = block_rows
173+
.into_iter()
174+
.map(|(block_id, blocks)| {
175+
let classes = blocks
176+
.into_iter()
177+
.map(|rows| {
178+
let value_to_cells = rows
179+
.into_iter()
180+
.enumerate()
181+
.flat_map(|(instruction_index, row_index)| {
182+
trace.rows[*row_index]
183+
.iter()
184+
.enumerate()
185+
.map(|(col_index, v)| (*v, (instruction_index, col_index)))
186+
.collect::<Vec<_>>()
187+
})
188+
.into_group_map();
189+
value_to_cells.values().cloned().collect::<Vec<_>>()
190+
})
191+
.collect::<Vec<_>>();
192+
(block_id, classes)
193+
})
194+
.collect::<HashMap<_, _>>();
195+
196+
// Intersect equivalence classes across all instances
197+
let intersected_equivalence_classes = equivalence_classes
198+
.into_iter()
199+
.map(|(block_id, classes)| {
200+
let intersected = intersect_partitions(&classes);
201+
202+
// Remove singleton classes
203+
let intersected = intersected
204+
.into_iter()
205+
.filter(|class| class.len() > 1)
206+
.collect::<Vec<_>>();
207+
208+
(block_id, intersected)
209+
})
210+
.collect::<BTreeMap<_, _>>();
211+
212+
// Map all column values to their range (1st and 99th percentile) for each pc
213+
let column_ranges_by_pc: HashMap<u32, Vec<(u32, u32)>> = trace
214+
.rows_by_pc
215+
.into_iter()
216+
.map(|(pc, pc_rows)| {
217+
let rows = pc_rows
218+
.into_iter()
219+
.map(|row_index| &trace.rows[row_index])
220+
.collect::<Vec<_>>();
221+
for row in &rows {
222+
// All rows for a given PC should be in the same chip
223+
assert_eq!(row.len(), rows[0].len());
224+
}
225+
let column_ranges = (0..rows[0].len())
226+
.map(|col_index| {
227+
let mut values = rows.iter().map(|row| row[col_index]).collect::<Vec<_>>();
228+
values.sort_unstable();
229+
let len = values.len();
230+
let p1_index = len / 100; // 1st percentile
231+
let p99_index = len * 99 / 100; // 99th percentile
232+
(values[p1_index], values[p99_index])
233+
})
234+
.collect();
235+
(pc, column_ranges)
236+
})
237+
.collect();
238+
239+
EmpiricalConstraints {
240+
column_ranges_by_pc: column_ranges_by_pc.into_iter().collect(),
241+
equivalence_classes_by_block: intersected_equivalence_classes,
242+
}
243+
}
244+
245+
/// Computes the coarsest common refinement of a set of partitions.
///
/// Two elements end up in the same output class iff they share a class in
/// *every* input partition. Elements that are absent from at least one
/// partition form singleton classes and are omitted from the output entirely.
/// The order of the returned classes is unspecified.
fn intersect_partitions<Id>(partitions: &[Vec<Vec<Id>>]) -> Vec<Vec<Id>>
where
    Id: Eq + Hash + Copy,
{
    if partitions.is_empty() {
        return Vec::new();
    }

    // Per-partition lookup tables: element -> index of its class within that
    // partition.
    let class_lookup: Vec<HashMap<Id, usize>> = partitions
        .iter()
        .map(|partition| {
            partition
                .iter()
                .enumerate()
                .flat_map(|(class_idx, class)| class.iter().map(move |&id| (id, class_idx)))
                .collect()
        })
        .collect();

    // All elements occurring anywhere in any partition.
    let universe: HashSet<Id> = partitions
        .iter()
        .flat_map(|partition| partition.iter().flatten())
        .copied()
        .collect();

    // Group elements by their "signature": the vector of class indices across
    // all partitions. A `None` lookup means the element is missing from some
    // partition, i.e. it is its own singleton class and can be dropped.
    let mut by_signature: HashMap<Vec<usize>, Vec<Id>> = HashMap::new();
    for &id in &universe {
        let signature: Option<Vec<usize>> = class_lookup
            .iter()
            .map(|lookup| lookup.get(&id).copied())
            .collect();
        if let Some(signature) = signature {
            by_signature.entry(signature).or_default().push(id);
        }
    }

    // The grouped values are the resulting equivalence classes.
    by_signature.into_values().collect()
}

openvm/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use crate::powdr_extension::{PowdrExtensionExecutor, PowdrPrecompile};
6666
mod air_builder;
6767
pub mod bus_map;
6868
pub mod cuda_abi;
69+
mod empirical_constraints;
6970
pub mod extraction_utils;
7071
pub mod opcode;
7172
mod program;
@@ -76,6 +77,8 @@ pub use opcode::instruction_allowlist;
7677
pub use powdr_autoprecompiles::DegreeBound;
7778
pub use powdr_autoprecompiles::PgoConfig;
7879

80+
pub use crate::empirical_constraints::detect_empirical_constraints;
81+
7982
pub type BabyBearSC = BabyBearPoseidon2Config;
8083

8184
cfg_if::cfg_if! {

0 commit comments

Comments
 (0)