Skip to content

Commit 0a165fb

Browse files
committed
sca: only replace constant one when necessary
1 parent 5bca067 commit 0a165fb

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

patronus-sca/src/lib.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ pub fn verify_word_level_equality(ctx: &mut Context, p: ScaEqualityProblem) -> S
3636
word_poly.add_assign(&output_poly);
3737
let spec = word_poly;
3838

39+
// collect all (bit-level) input variables
40+
let input_vars: FxHashSet<VarIndex> = inputs
41+
.iter()
42+
.flat_map(|&e| {
43+
let width = e.get_bv_type(ctx).unwrap();
44+
let vars: Vec<_> = (0..width)
45+
.map(|bit| expr_to_var(extract_bit(ctx, e, bit)))
46+
.collect();
47+
vars
48+
})
49+
.collect();
50+
3951
// create todos for all output variables
4052
let gate_outputs: Vec<_> = (0..width)
4153
.map(|ii| {
@@ -49,7 +61,7 @@ pub fn verify_word_level_equality(ctx: &mut Context, p: ScaEqualityProblem) -> S
4961
.collect();
5062

5163
// now we can perform backwards substitution
52-
let result = backwards_sub(ctx, gate_outputs.into(), spec);
64+
let result = backwards_sub(ctx, &input_vars, gate_outputs.into(), spec);
5365

5466
if result.is_zero() {
5567
ScaVerifyResult::Equal
@@ -159,19 +171,18 @@ fn var_to_expr(v: polysub::VarIndex) -> ExprRef {
159171
usize::from(v).into()
160172
}
161173

162-
fn is_gate(expr: &Expr) -> bool {
163-
matches!(
164-
expr,
165-
Expr::BVNot(_, 1) | Expr::BVAnd(_, _, 1) | Expr::BVOr(_, _, 1) | Expr::BVXor(_, _, 1)
166-
)
167-
}
168-
169-
fn backwards_sub(ctx: &Context, mut todo: Vec<(VarIndex, ExprRef)>, mut spec: Poly) -> Poly {
174+
fn backwards_sub(
175+
ctx: &Context,
176+
input_vars: &FxHashSet<VarIndex>,
177+
mut todo: Vec<(VarIndex, ExprRef)>,
178+
mut spec: Poly,
179+
) -> Poly {
170180
let mut var_roots: Vec<_> = todo.iter().map(|(v, _)| *v).collect();
171181
var_roots.sort();
172182

173183
let m = spec.get_mod();
174184
let one: DefaultCoef = Coef::from_i64(1, m);
185+
let zero: DefaultCoef = Coef::from_i64(0, m);
175186
let minus_one: DefaultCoef = Coef::from_i64(-1, m);
176187
let minus_two: DefaultCoef = Coef::from_i64(-2, m);
177188
// first, we count how often expressions are used
@@ -183,7 +194,7 @@ fn backwards_sub(ctx: &Context, mut todo: Vec<(VarIndex, ExprRef)>, mut spec: Po
183194

184195
while let Some((output_var, gate)) = todo.pop() {
185196
replaced.push(output_var);
186-
// println!("{output_var}: {}", spec.size());
197+
println!("{output_var} {:?}: {}", &ctx[gate], spec.size());
187198

188199
let add_children = match ctx[gate].clone() {
189200
Expr::BVOr(a, b, 1) => {
@@ -232,15 +243,27 @@ fn backwards_sub(ctx: &Context, mut todo: Vec<(VarIndex, ExprRef)>, mut spec: Po
232243
spec.replace_var(output_var, &monoms);
233244
true
234245
}
235-
Expr::BVSlice { hi, lo, .. } => {
246+
Expr::BVSlice { hi, lo, e } => {
236247
assert_eq!(hi, lo);
248+
assert!(
249+
input_vars.contains(&expr_to_var(e)),
250+
"Not actually an input: {e:?}"
251+
);
237252
// a bit slice normally marks an input, thus we should be done!
238253
false
239254
}
255+
Expr::BVLiteral(value) => {
256+
let value = value.get(ctx);
257+
debug_assert_eq!(value.width(), 1);
258+
if value.is_true() {
259+
spec.replace_var(output_var, &[(one.clone(), vec![].into())]);
260+
} else {
261+
spec.replace_var(output_var, &[(zero.clone(), vec![].into())]);
262+
}
263+
false
264+
}
240265
other => todo!("add support for {other:?}"),
241266
};
242-
// get rid of any ones, TODO: we should eventually write better code and make this unnecessary
243-
spec.replace_var(const_true_var, &[(one.clone(), vec![].into())]);
244267

245268
if add_children {
246269
ctx[gate].for_each_child(|&e| {
@@ -249,8 +272,9 @@ fn backwards_sub(ctx: &Context, mut todo: Vec<(VarIndex, ExprRef)>, mut spec: Po
249272
assert!(prev_uses > 0);
250273
uses[e] = prev_uses - 1;
251274
// did the use count just go down to zero?
252-
if prev_uses == 1 && is_gate(&ctx[e]) {
253-
todo.push((expr_to_var(e), e));
275+
let var = expr_to_var(e);
276+
if prev_uses == 1 && !input_vars.contains(&var) {
277+
todo.push((var, e));
254278
}
255279
});
256280
}

0 commit comments

Comments
 (0)