Skip to content

Commit 187778f

Browse files
authored
Replace computation method "InverseOrZero" by "QuotientOrZero". (#3457)
Replace the computation method "InverseOrZero" by "QuotientOrZero" which is more flexible.
1 parent 02d64e0 commit 187778f

File tree

6 files changed

+48
-32
lines changed

6 files changed

+48
-32
lines changed

autoprecompiles/src/optimizer.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,10 @@ fn symbolic_machine_to_constraint_system<P: FieldElement>(
217217
.map(|(v, method)| {
218218
let method = match method {
219219
ComputationMethod::Constant(c) => ComputationMethod::Constant(*c),
220-
ComputationMethod::InverseOrZero(c) => {
221-
ComputationMethod::InverseOrZero(algebraic_to_grouped_expression(c))
222-
}
220+
ComputationMethod::QuotientOrZero(e1, e2) => ComputationMethod::QuotientOrZero(
221+
algebraic_to_grouped_expression(e1),
222+
algebraic_to_grouped_expression(e2),
223+
),
223224
};
224225
DerivedVariable {
225226
variable: v.clone(),
@@ -250,9 +251,10 @@ fn constraint_system_to_symbolic_machine<P: FieldElement>(
250251
.map(|derived_var| {
251252
let method = match derived_var.computation_method {
252253
ComputationMethod::Constant(c) => ComputationMethod::Constant(c),
253-
ComputationMethod::InverseOrZero(c) => {
254-
ComputationMethod::InverseOrZero(grouped_expression_to_algebraic(c))
255-
}
254+
ComputationMethod::QuotientOrZero(e1, e2) => ComputationMethod::QuotientOrZero(
255+
grouped_expression_to_algebraic(e1),
256+
grouped_expression_to_algebraic(e2),
257+
),
256258
};
257259
(derived_var.variable, method)
258260
})

autoprecompiles/src/symbolic_machine_generator.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ pub fn convert_machine_field_type<T, U>(
3535
ComputationMethod::Constant(c) => {
3636
ComputationMethod::Constant(convert_field_element(c))
3737
}
38-
ComputationMethod::InverseOrZero(e) => ComputationMethod::InverseOrZero(
39-
convert_expression(e, convert_field_element),
38+
ComputationMethod::QuotientOrZero(e1, e2) => ComputationMethod::QuotientOrZero(
39+
convert_expression(e1, convert_field_element),
40+
convert_expression(e2, convert_field_element),
4041
),
4142
};
4243
(v, method)

constraint-solver/src/constraint_system.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,16 @@ pub struct DerivedVariable<T, V> {
9494
pub enum ComputationMethod<T, E> {
9595
/// A constant value.
9696
Constant(T),
97-
/// The field inverse of an expression if it exists or zero otherwise.
98-
InverseOrZero(E),
97+
/// The quotiont (using inversion in the field) of the first argument
98+
/// by the second argument, or zero if the latter is zero.
99+
QuotientOrZero(E, E),
99100
}
100101

101102
impl<T: Display, E: Display> Display for ComputationMethod<T, E> {
102103
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103104
match self {
104105
ComputationMethod::Constant(c) => write!(f, "{c}"),
105-
ComputationMethod::InverseOrZero(e) => write!(f, "InverseOrZero({e})"),
106+
ComputationMethod::QuotientOrZero(e1, e2) => write!(f, "QuotientOrZero({e1}, {e2})"),
106107
}
107108
}
108109
}
@@ -112,7 +113,10 @@ impl<T, F> ComputationMethod<T, GroupedExpression<T, F>> {
112113
pub fn referenced_unknown_variables(&self) -> Box<dyn Iterator<Item = &F> + '_> {
113114
match self {
114115
ComputationMethod::Constant(_) => Box::new(std::iter::empty()),
115-
ComputationMethod::InverseOrZero(e) => e.referenced_unknown_variables(),
116+
ComputationMethod::QuotientOrZero(e1, e2) => Box::new(
117+
e1.referenced_unknown_variables()
118+
.chain(e2.referenced_unknown_variables()),
119+
),
116120
}
117121
}
118122
}
@@ -125,8 +129,9 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>
125129
pub fn substitute_by_known(&mut self, variable: &V, substitution: &T) {
126130
match self {
127131
ComputationMethod::Constant(_) => {}
128-
ComputationMethod::InverseOrZero(e) => {
129-
e.substitute_by_known(variable, substitution);
132+
ComputationMethod::QuotientOrZero(e1, e2) => {
133+
e1.substitute_by_known(variable, substitution);
134+
e2.substitute_by_known(variable, substitution);
130135
}
131136
}
132137
}
@@ -138,8 +143,9 @@ impl<T: RuntimeConstant + Substitutable<V>, V: Ord + Clone + Eq>
138143
pub fn substitute_by_unknown(&mut self, variable: &V, substitution: &GroupedExpression<T, V>) {
139144
match self {
140145
ComputationMethod::Constant(_) => {}
141-
ComputationMethod::InverseOrZero(e) => {
142-
e.substitute_by_unknown(variable, substitution);
146+
ComputationMethod::QuotientOrZero(e1, e2) => {
147+
e1.substitute_by_unknown(variable, substitution);
148+
e2.substitute_by_unknown(variable, substitution);
143149
}
144150
}
145151
}

constraint-solver/src/indexed_constraint_system.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -829,30 +829,32 @@ mod tests {
829829
derived_variables: vec![
830830
DerivedVariable {
831831
variable: "d1",
832-
computation_method: ComputationMethod::InverseOrZero(
833-
GroupedExpression::from_unknown_variable("x"),
832+
computation_method: ComputationMethod::QuotientOrZero(
833+
GroupedExpression::from_unknown_variable("x1"),
834+
GroupedExpression::from_unknown_variable("x2"),
834835
),
835836
},
836837
DerivedVariable {
837838
variable: "d2",
838-
computation_method: ComputationMethod::InverseOrZero(
839-
GroupedExpression::from_unknown_variable("y"),
839+
computation_method: ComputationMethod::QuotientOrZero(
840+
GroupedExpression::from_unknown_variable("y1"),
841+
GroupedExpression::from_unknown_variable("y2"),
840842
),
841843
},
842844
],
843845
}
844846
.into();
845-
// We first substitute `y` by an expression that contains `x` such that when we
846-
// substitute `x` in the next step, `d2` has to be updated again.
847+
// We first substitute `y2` by an expression that contains `x1` such that when we
848+
// substitute `x1` in the next step, `d2` has to be updated again.
847849
system.substitute_by_unknown(
848-
&"y",
849-
&(GroupedExpression::from_unknown_variable("x")
850+
&"y2",
851+
&(GroupedExpression::from_unknown_variable("x1")
850852
+ GroupedExpression::from_number(7.into())),
851853
);
852-
system.substitute_by_known(&"x", &1.into());
854+
system.substitute_by_known(&"x1", &1.into());
853855
assert_eq!(
854856
format!("{system}"),
855-
"d1 := InverseOrZero(1)\nd2 := InverseOrZero(8)"
857+
"d1 := QuotientOrZero(1, x2)\nd2 := QuotientOrZero(y1, 8)"
856858
);
857859
}
858860
}

openvm/src/powdr_extension/trace_generator/cpu/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,19 @@ impl PowdrTraceGeneratorCpu {
183183
let col_index = apc_poly_id_to_index[&column.id];
184184
row_slice[col_index] = match computation_method {
185185
ComputationMethod::Constant(c) => *c,
186-
ComputationMethod::InverseOrZero(expr) => {
186+
ComputationMethod::QuotientOrZero(e1, e2) => {
187187
use powdr_number::ExpressionConvertible;
188188

189-
let expr_val = expr.to_expression(&|n| *n, &|column_ref| {
189+
let divisor_val = e2.to_expression(&|n| *n, &|column_ref| {
190190
row_slice[apc_poly_id_to_index[&column_ref.id]]
191191
});
192-
if expr_val.is_zero() {
192+
if divisor_val.is_zero() {
193193
BabyBear::ZERO
194194
} else {
195-
expr_val.inverse()
195+
divisor_val.inverse()
196+
* e1.to_expression(&|n| *n, &|column_ref| {
197+
row_slice[apc_poly_id_to_index[&column_ref.id]]
198+
})
196199
}
197200
}
198201
};

openvm/src/powdr_extension/trace_generator/cuda/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,12 @@ fn compile_derived_to_gpu(
120120
bytecode.push(OpCode::PushConst as u32);
121121
bytecode.push(c.as_canonical_u32());
122122
}
123-
ComputationMethod::InverseOrZero(expr) => {
123+
ComputationMethod::QuotientOrZero(e1, e2) => {
124124
// Encode inner expression, then apply InvOrZero
125-
emit_expr(&mut bytecode, expr, apc_poly_id_to_index, apc_height);
125+
emit_expr(&mut bytecode, e2, apc_poly_id_to_index, apc_height);
126126
bytecode.push(OpCode::InvOrZero as u32);
127+
emit_expr(&mut bytecode, e1, apc_poly_id_to_index, apc_height);
128+
bytecode.push(OpCode::Mul as u32);
127129
}
128130
}
129131
let len = (bytecode.len() as u32) - off;

0 commit comments

Comments
 (0)