Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 103 additions & 27 deletions autoprecompiles/src/empirical_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,30 @@ use std::hash::Hash;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

pub use crate::empirical_constraints::equivalence_class::EquivalenceClass;

/// "Constraints" that were inferred from execution statistics. They hold empirically
/// (most of the time), but are not guaranteed to hold in all cases.
#[derive(Serialize, Deserialize, Clone, Default, Debug)]
#[derive(Serialize, Default, Debug)]
pub struct EmpiricalConstraints {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge DebugInfo into EmpiricalConstraints to simplify return types

/// For each program counter, the range constraints for each column.
/// The range might not hold in 100% of cases.
pub column_ranges_by_pc: BTreeMap<u32, Vec<(u32, u32)>>,
/// For each basic block (identified by its starting PC), the equivalence classes of columns.
/// Each equivalence class is a list of (instruction index in block, column index).
pub equivalence_classes_by_block: BTreeMap<u64, BTreeSet<BTreeSet<(usize, usize)>>>,
pub equivalence_classes_by_block: BTreeMap<u64, EquivalenceClasses<BlockCell>>,
pub debug_info: DebugInfo,
}

/// Debug information mapping AIR ids to program counters and column names.
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, Debug)]
pub struct DebugInfo {
/// Mapping from program counter to AIR id.
pub air_id_by_pc: BTreeMap<u32, usize>,
/// Mapping from AIR id to column names.
pub column_names_by_air_id: BTreeMap<usize, Vec<String>>,
}

#[derive(Serialize, Deserialize)]
pub struct EmpiricalConstraintsJson {
pub empirical_constraints: EmpiricalConstraints,
pub debug_info: DebugInfo,
}

impl EmpiricalConstraints {
pub fn combine_with(&mut self, other: EmpiricalConstraints) {
// Combine column ranges by PC
Expand All @@ -52,15 +49,14 @@ impl EmpiricalConstraints {

// Combine equivalence classes by block
for (block_pc, classes) in other.equivalence_classes_by_block {
self.equivalence_classes_by_block
let existing = self
.equivalence_classes_by_block
.entry(block_pc)
.and_modify(|existing_classes| {
let combined =
intersect_partitions(&[existing_classes.clone(), classes.clone()]);
*existing_classes = combined;
})
.or_insert(classes);
.or_default();

*existing = intersect_partitions(vec![std::mem::take(existing), classes]);
}
self.debug_info.combine_with(other.debug_info);
}
}

Expand Down Expand Up @@ -88,19 +84,97 @@ fn merge_maps<K: Ord, V: Eq + Debug>(map1: &mut BTreeMap<K, V>, map2: BTreeMap<K
}
}

#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug, Copy, Clone)]
pub struct BlockCell {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduce type for the equivalence class members, aka cells of a block execution

/// The row index which is also the instruction index within the basic block
row_idx: usize,
/// The column index within the instruction air
column_idx: usize,
}

impl BlockCell {
pub fn new(row_idx: usize, column_idx: usize) -> Self {
Self {
row_idx,
column_idx,
}
}
}

mod equivalence_class {
use std::collections::BTreeSet;

use serde::Serialize;

/// An equivalence class with the following guarantees
/// - It cannot be empty
/// - It cannot hold a single element
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of filtering out equivalence classes, keep them out by construction

#[derive(Serialize, Debug, PartialOrd, Ord, PartialEq, Eq)]
pub struct EquivalenceClass<T> {
inner: BTreeSet<T>,
}

impl<T> Default for EquivalenceClass<T> {
fn default() -> Self {
Self {
inner: BTreeSet::default(),
}
}
}

impl<T> EquivalenceClass<T> {
pub fn iter(&self) -> impl Iterator<Item = &T> {
self.inner.iter()
}
}

impl<T: Ord> FromIterator<T> for EquivalenceClass<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
let inner: BTreeSet<_> = iter.into_iter().collect();
if inner.len() > 1 {
Self { inner }
} else {
Self::default()
}
}
}
}

/// A collection of equivalence classes
#[derive(Serialize, Debug, PartialEq, Eq)]
pub struct EquivalenceClasses<T> {
inner: BTreeSet<EquivalenceClass<T>>,
}

impl<T> Default for EquivalenceClasses<T> {
fn default() -> Self {
Self {
inner: Default::default(),
}
}
}

impl<T: Ord> FromIterator<EquivalenceClass<T>> for EquivalenceClasses<T> {
fn from_iter<I: IntoIterator<Item = EquivalenceClass<T>>>(iter: I) -> Self {
Self {
inner: iter.into_iter().collect(),
}
}
}

/// Intersects multiple partitions of the same universe into a single partition.
/// In other words, two elements are in the same equivalence class in the resulting partition
/// if and only if they are in the same equivalence class in all input partitions.
/// Singleton equivalence classes are omitted from the result.
pub fn intersect_partitions<Id>(partitions: &[BTreeSet<BTreeSet<Id>>]) -> BTreeSet<BTreeSet<Id>>
where
Id: Eq + Hash + Copy + Ord,
{
pub fn intersect_partitions<T: Eq + Hash + Copy + Ord>(
partitions: Vec<EquivalenceClasses<T>>,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move

) -> EquivalenceClasses<T> {
// For each partition, build a map: Id -> class_index
let class_ids: Vec<HashMap<Id, usize>> = partitions
let class_ids: Vec<HashMap<T, usize>> = partitions
.iter()
.map(|partition| {
partition
.inner
.iter()
.enumerate()
.flat_map(|(class_idx, class)| class.iter().map(move |&id| (id, class_idx)))
Expand All @@ -109,9 +183,9 @@ where
.collect();

// Iterate over all elements in the universe
partitions
let res = partitions
.iter()
.flat_map(|partition| partition.iter())
.flat_map(|partition| &partition.inner)
.flat_map(|class| class.iter().copied())
.unique()
.filter_map(|id| {
Expand All @@ -130,14 +204,16 @@ where
.into_values()
// Remove singletons and convert to Set
.filter_map(|ids| (ids.len() > 1).then_some(ids.into_iter().collect()))
.collect()
.collect();

EquivalenceClasses { inner: res }
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use crate::empirical_constraints::EquivalenceClasses;

fn partition(sets: Vec<Vec<u32>>) -> BTreeSet<BTreeSet<u32>> {
fn partition(sets: Vec<Vec<u32>>) -> EquivalenceClasses<u32> {
sets.into_iter().map(|s| s.into_iter().collect()).collect()
}

Expand All @@ -156,7 +232,7 @@ mod tests {
vec![6, 7, 8],
]);

let result = super::intersect_partitions(&[partition1, partition2]);
let result = super::intersect_partitions(vec![partition1, partition2]);

let expected = partition(vec![vec![2, 3], vec![6, 7, 8]]);

Expand Down
9 changes: 2 additions & 7 deletions cli-openvm/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use metrics_tracing_context::{MetricsLayer, TracingContextLayer};
use metrics_util::{debugging::DebuggingRecorder, layers::Layer};
use openvm_sdk::StdIn;
use openvm_stark_sdk::bench::serialize_metric_snapshot;
use powdr_autoprecompiles::empirical_constraints::EmpiricalConstraintsJson;
use powdr_autoprecompiles::pgo::{pgo_config, PgoType};
use powdr_autoprecompiles::PowdrConfig;
use powdr_openvm::{compile_openvm, default_powdr_openvm_config, CompiledProgram, GuestOptions};
Expand Down Expand Up @@ -311,19 +310,15 @@ fn maybe_compute_empirical_constraints(
"Optimistic precompiles are not implemented yet. Computing empirical constraints..."
);

let (empirical_constraints, debug_info) =
let empirical_constraints =
detect_empirical_constraints(guest_program, powdr_config.degree_bound, vec![stdin]);

if let Some(path) = &powdr_config.apc_candidates_dir_path {
tracing::info!(
"Saving empirical constraints debug info to {}/empirical_constraints.json",
path.display()
);
let export = EmpiricalConstraintsJson {
empirical_constraints: empirical_constraints.clone(),
debug_info,
};
let json = serde_json::to_string_pretty(&export).unwrap();
let json = serde_json::to_string_pretty(&empirical_constraints).unwrap();
std::fs::write(path.join("empirical_constraints.json"), json).unwrap();
}
}
Loading
Loading