Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 214 additions & 15 deletions autoprecompiles/src/constraint_optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
fmt::Display,
hash::Hash,
iter::once,
Expand All @@ -11,9 +11,10 @@ use powdr_constraint_solver::{
constraint_system::{
AlgebraicConstraint, BusInteractionHandler, ConstraintRef, ConstraintSystem,
},
grouped_expression::GroupedExpression,
grouped_expression::{GroupedExpression, GroupedExpressionComponent},
indexed_constraint_system::IndexedConstraintSystem,
inliner::DegreeBound,
range_constraint::RangeConstraint,
reachability::reachable_variables,
solver::Solver,
};
Expand Down Expand Up @@ -75,6 +76,9 @@ pub fn optimize_constraints<
remove_disconnected_columns(constraint_system, solver, bus_interaction_handler.clone());
stats_logger.log("removing disconnected columns", &constraint_system);

let constraint_system = combine_free_variables(constraint_system, solver);
stats_logger.log("combining free variables", &constraint_system);

let constraint_system = trivial_simplifications(
constraint_system,
bus_interaction_handler.clone(),
Expand Down Expand Up @@ -206,22 +210,13 @@ fn remove_free_variables<T: FieldElement, V: Clone + Ord + Eq + Hash + Display>(
bus_interaction_handler: impl IsBusStateful<T> + Clone,
) -> IndexedConstraintSystem<T, V> {
let all_variables = constraint_system
.system()
.referenced_unknown_variables()
.cloned()
.collect::<HashSet<_>>();

let variables_to_delete = all_variables
.iter()
// Find variables that are referenced in exactly one constraint
.filter_map(|variable| {
constraint_system
.constraints_referencing_variables(once(variable))
.exactly_one()
.ok()
.map(|constraint| (variable.clone(), constraint))
})
.filter(|(variable, constraint)| match constraint {
// Find variables that are referenced in exactly one constraint
let variables_to_delete = single_occurrence_variables(&constraint_system)
.filter(|(constraint, variable)| match constraint {
// Remove the algebraic constraint if we can solve for the variable.
ConstraintRef::AlgebraicConstraint(constr) => {
can_always_be_satisfied_via_free_variable(*constr, variable)
Expand Down Expand Up @@ -253,7 +248,7 @@ fn remove_free_variables<T: FieldElement, V: Clone + Ord + Eq + Hash + Display>(
is_stateless && has_one_unknown_field && all_degrees_at_most_one
}
})
.map(|(variable, _constraint)| variable.clone())
.map(|(constraint, variable)| variable.clone())
.collect::<HashSet<_>>();

let variables_to_keep = all_variables
Expand All @@ -280,6 +275,23 @@ fn remove_free_variables<T: FieldElement, V: Clone + Ord + Eq + Hash + Display>(
constraint_system
}

/// Returns pairs of constraints and variables such that the variable occurs only
/// in the given constraint.
fn single_occurrence_variables<T: FieldElement, V: Clone + Ord + Eq + Hash + Display>(
constraint_system: &IndexedConstraintSystem<T, V>,
) -> impl Iterator<Item = (ConstraintRef<T, V>, V)> {
constraint_system
.referenced_unknown_variables()
.unique()
.filter_map(|variable| {
constraint_system
.constraints_referencing_variables(once(variable))
.exactly_one()
.ok()
.map(|constraint| (constraint, variable.clone()))
})
}

/// Returns true if the given constraint can always be made to be satisfied by setting the
/// free variable, regardless of the values of other variables.
fn can_always_be_satisfied_via_free_variable<
Expand All @@ -303,6 +315,193 @@ fn can_always_be_satisfied_via_free_variable<
}
}

/// Tries to combine multiple variables that only occur in the same algebraic
/// constraint.
///
/// The simplified pattern is `X * V1 + Y * V2 = C`, where `V1` and `V2` only occur
/// here and only once.
/// The only combination of values for `X`, `Y` and `C` where this is not satisfiable
/// is `X = 0`, `Y = 0`, `C != 0`. So the constraint is equivalent to the statement
/// `(X = 0 and Y = 0) -> C = 0`.
///
/// Considering the simpler case where both `X` and `Y` are non-negative such that
/// `X + Y` does not wrap.
/// Then `X = 0 and Y = 0` is equivalent to `X + Y = 0`. So we can replace the constraint
/// by `(X + Y) * V3 = C`, where `V3` is a new variable that only occurs here.
///
/// If e.g. `X` can be negative, we replace it by `X * X`, if that value is still small enough.
fn combine_free_variables<T: FieldElement, V: Clone + Ord + Eq + Hash + Display>(
mut constraint_system: IndexedConstraintSystem<T, V>,
solver: &mut impl Solver<T, V>,
) -> IndexedConstraintSystem<T, V> {
// TODO tracegen needs to be modified
let single_occurrence = single_occurrence_variables(&constraint_system)
.into_group_map()
.into_iter()
.flat_map(|(c, v)| match c {
ConstraintRef::AlgebraicConstraint(constr) => Some((constr.clone(), v)),
ConstraintRef::BusInteraction(_bus_interaction) => None,
})
.filter_map(|(c, vars)| {
// Keep only the variables that occur exactly once
// and then filter out the constraints where less than two
// variables are left.
let vars = vars
.iter()
.filter(|var| {
c.referenced_unknown_variables()
.filter(|v| v == var)
.count()
== 1
})
.cloned()
.collect_vec();
if vars.len() <= 1 {
None
} else {
Some((c, vars))
}
})
.flat_map(|(c, v)| {
FreeVariablePatternMatch::try_from_constraint(&c, v.iter().cloned().collect())
})
.collect_vec();
let mut constraints_to_add = vec![];
let mut vars_to_remove = BTreeSet::<V>::new();
for pattern_match in single_occurrence {
let mut grouped: GroupedExpression<T, V> = Zero::zero();
let mut grouped_rc = RangeConstraint::<T>::from_value(0.into());
let mut rest = pattern_match.rest;
let mut vars_replaced = BTreeSet::new();
// TODO could use fold here.
for (v, f, rc) in pattern_match
.variables
.into_iter()
.map(|(v, f)| {
let rc = f.range_constraint(solver);
(v, f, rc)
})
.sorted_by_key(|(_, _, rc)| rc.range_width())
{
// TODO we could actually also try to extract a factor from `f` to
// reduce its RC.
if rc.range().0 <= rc.range().1 {
// TODO not sure if this is correct.
// what we need ot check is that the only way for X + Y = 0
// is that X = 0 and Y = 0.
if rc.combine_sum(&grouped_rc).range_width() != T::modulus() {
grouped += f;
grouped_rc = rc.combine_sum(&grouped_rc);
vars_replaced.insert(v.clone());
} else {
rest += f * GroupedExpression::from_unknown_variable(v);
}
} else {
// TODO same here
if rc.square().combine_sum(&grouped_rc).range_width() != T::modulus() {
grouped += f.clone() * f;
grouped_rc = rc.square().combine_sum(&grouped_rc);
vars_replaced.insert(v.clone());
} else {
rest += f * GroupedExpression::from_unknown_variable(v);
}
}
}
if vars_replaced.len() >= 2 {
// TODO use a new variable
constraints_to_add.push(AlgebraicConstraint::assert_zero(
grouped
* GroupedExpression::from_unknown_variable(
vars_replaced.iter().next().unwrap().clone(),
)
+ rest,
));
vars_to_remove.extend(vars_replaced);
}
}

constraint_system.retain_algebraic_constraints(|constr| {
!constr
.referenced_unknown_variables()
.any(|v| vars_to_remove.contains(v))
});
constraint_system.add_algebraic_constraints(constraints_to_add);

constraint_system
}

/// This pattern match corresponds to the constraint
/// \sum_{(v, f) in variables} f * v + rest = 0
/// such that all variables are different single-occurrence variables.
struct FreeVariablePatternMatch<T, V> {
variables: BTreeMap<V, GroupedExpression<T, V>>,
rest: GroupedExpression<T, V>,
}

impl<T: FieldElement, V: Clone + Ord + Eq + Hash + Display> FreeVariablePatternMatch<T, V> {
fn try_from_constraint(
constraint: &AlgebraicConstraint<&GroupedExpression<T, V>>,
single_occurrence_variables: BTreeSet<V>,
) -> Option<Self> {
if single_occurrence_variables.len() < 2 {
return None;
}
let mut variables = BTreeMap::new();
let rest = constraint
.expression
.clone()
.into_summands()
.filter_map(|item| match item {
GroupedExpressionComponent::Constant(c) => Some(GroupedExpression::from_number(c)),
GroupedExpressionComponent::Linear(v, c) => {
Some(GroupedExpression::from_unknown_variable(v) * c)
}
GroupedExpressionComponent::Quadratic(l, r) => {
if let Some((c, v)) =
Self::try_expression_to_single_occurrence_variable_multiple(
&l,
&single_occurrence_variables,
)
{
variables.insert(v, r.clone() * c);
None
} else if let Some((c, v)) =
Self::try_expression_to_single_occurrence_variable_multiple(
&r,
&single_occurrence_variables,
)
{
variables.insert(v, l.clone() * c);
None
} else {
Some(l * r)
}
}
})
.sum();
if variables.len() >= 2 {
Some(Self { variables, rest })
} else {
None
}
}

/// If `expr` is of the form `c * v` where `v` is a single occurrence variable
/// and `c` a constant expression, returns `Some((c, v))`.
fn try_expression_to_single_occurrence_variable_multiple(
expr: &GroupedExpression<T, V>,
single_occurrence_variables: &BTreeSet<V>,
) -> Option<(T, V)> {
if expr.is_affine() && expr.constant_offset().is_zero() {
let (v, c) = expr.linear_components().exactly_one().ok()?;
if single_occurrence_variables.contains(v) {
return Some((*c, v.clone()));
}
}
None
}
}

/// Removes any columns that are not connected to *stateful* bus interactions (e.g. memory),
/// because those are the only way to interact with the rest of the zkVM (e.g. other
/// instructions).
Expand Down
1 change: 1 addition & 0 deletions constraint-solver/src/grouped_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
self.quadratic
.iter()
.map(|(l, r)| {
// TODO use rc.square() if there is an f such that l = r * f
l.range_constraint(range_constraints)
.combine_product(&r.range_constraint(range_constraints))
})
Expand Down
16 changes: 16 additions & 0 deletions constraint-solver/src/range_constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ impl<T: FieldElement> RangeConstraint<T> {
}
}

pub fn square(&self) -> Self {
let square_rc = (self.min > self.max)
.then(|| {
let max_abs = std::cmp::max(-self.min, self.max);
if max_abs.to_arbitrary_integer() * max_abs.to_arbitrary_integer()
< T::modulus().to_arbitrary_integer()
{
Self::from_range(T::zero(), max_abs * max_abs)
} else {
Default::default()
}
})
.unwrap_or_default();
self.combine_product(self).conjunction(&square_rc)
}

/// Returns the conjunction of this constraint and the other.
/// This operation is not lossless, but if `r1` and `r2` allow
/// a value `x`, then `r1.conjunction(r2)` also allows `x`.
Expand Down
Loading
Loading