Skip to content

Commit 11c6c4d

Browse files
committed
various improvements and fixes from the sca branch
1 parent b2460a9 commit 11c6c4d

File tree

16 files changed

+302
-164
lines changed

16 files changed

+302
-164
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ authors = ["Kevin Laeufer <laeufer@berkeley.edu>"]
88
repository = "https://github.com/cucapra/patronus"
99
readme = "Readme.md"
1010
license = "BSD-3-Clause"
11-
rust-version = "1.85.0"
11+
rust-version = "1.88.0"
1212
# used by main patronus library and python bindings
1313
version = "0.35.0"
1414

patronus-egraphs/src/arithmetic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ fn convert_bin_op(
282282
width_out,
283283
width_a,
284284
sign_a,
285-
converted_b,
285+
converted_a,
286286
width_b,
287287
sign_b,
288-
converted_a,
288+
converted_b,
289289
]))
290290
}
291291

patronus/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod transform;
1414
pub mod traversal;
1515
mod types;
1616

17-
pub use analysis::{UseCountInt, count_expr_uses, update_expr_child_uses};
17+
pub use analysis::{UseCountInt, count_expr_uses, find_symbols, update_expr_child_uses};
1818
pub use context::{Builder, Context, ExprRef, StringRef};
1919
pub use eval::{SymbolValueStore, eval_array_expr, eval_bv_expr, eval_expr};
2020
pub use foreach::ForEachChild;

patronus/src/expr/analysis.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
// released under BSD 3-Clause License
44
// author: Kevin Laeufer <laeufer@cornell.edu>
55

6-
use crate::expr::{Context, ExprMap, ExprRef, ForEachChild, SparseExprMap};
6+
use crate::expr::{Context, ExprMap, ExprRef, ForEachChild, SparseExprMap, traversal};
7+
use rustc_hash::FxHashSet;
78

89
pub type UseCountInt = u16;
910

@@ -24,6 +25,17 @@ pub fn count_expr_uses(ctx: &Context, roots: Vec<ExprRef>) -> impl ExprMap<UseCo
2425
use_count
2526
}
2627

28+
/// Returns all symbols in the given expression.
29+
pub fn find_symbols(ctx: &Context, e: ExprRef) -> FxHashSet<ExprRef> {
30+
let mut out = FxHashSet::default();
31+
traversal::bottom_up(ctx, e, |ctx, e, _| {
32+
if ctx[e].is_symbol() {
33+
out.insert(e);
34+
}
35+
});
36+
out
37+
}
38+
2739
/// Increments the use counts for all children of the expression `expr` and
2840
/// adds any child encountered for the first time to the `todo` list.
2941
#[inline]

patronus/src/expr/context.rs

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
use crate::expr::TypeCheck;
1818
use crate::expr::nodes::*;
1919
use baa::{
20-
ArrayOps, BitVecValue, BitVecValueIndex, BitVecValueRef, IndexToRef, SparseArrayValue, Value,
20+
ArrayOps, BitVecOps, BitVecValue, BitVecValueIndex, BitVecValueRef, IndexToRef,
21+
SparseArrayValue, Value,
2122
};
2223
use rustc_hash::FxBuildHasher;
2324
use std::borrow::Borrow;
@@ -53,18 +54,20 @@ pub struct ExprRef(NonZeroU32);
5354
impl Debug for ExprRef {
5455
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
5556
// we need a custom implementation in order to show the zero based index
56-
write!(f, "ExprRef({})", self.index())
57+
let index: usize = (*self).into();
58+
write!(f, "ExprRef({})", index)
5759
}
5860
}
5961

60-
impl ExprRef {
61-
// TODO: reduce visibility to pub(crate)
62-
pub fn from_index(index: usize) -> Self {
63-
ExprRef(NonZeroU32::new((index + 1) as u32).unwrap())
62+
impl From<ExprRef> for usize {
63+
fn from(value: ExprRef) -> Self {
64+
(value.0.get() - 1) as usize
6465
}
66+
}
6567

66-
pub(crate) fn index(&self) -> usize {
67-
(self.0.get() - 1) as usize
68+
impl From<usize> for ExprRef {
69+
fn from(index: usize) -> Self {
70+
ExprRef(NonZeroU32::new((index + 1) as u32).unwrap())
6871
}
6972
}
7073

@@ -87,8 +90,8 @@ impl Default for Context {
8790
strings: Default::default(),
8891
exprs: Default::default(),
8992
values: Default::default(),
90-
true_expr_ref: ExprRef::from_index(0),
91-
false_expr_ref: ExprRef::from_index(0),
93+
true_expr_ref: 0.into(), // only a placeholder!
94+
false_expr_ref: 0.into(), // only a placeholder!
9295
};
9396
// create valid cached expressions
9497
out.false_expr_ref = out.zero(1);
@@ -105,7 +108,7 @@ impl Context {
105108

106109
pub(crate) fn add_expr(&mut self, value: Expr) -> ExprRef {
107110
let (index, _) = self.exprs.insert_full(value);
108-
ExprRef::from_index(index)
111+
index.into()
109112
}
110113

111114
pub fn string(&mut self, value: std::borrow::Cow<str>) -> StringRef {
@@ -127,7 +130,7 @@ impl Index<ExprRef> for Context {
127130

128131
fn index(&self, index: ExprRef) -> &Self::Output {
129132
self.exprs
130-
.get_index(index.index())
133+
.get_index(index.into())
131134
.expect("Invalid ExprRef!")
132135
}
133136
}
@@ -142,21 +145,15 @@ impl Index<StringRef> for Context {
142145
}
143146
}
144147

148+
/// Convenience methods to inspect IR nodes.
145149
impl Context {
146-
/// Returns the number of interned expressions in this context.
147-
pub fn num_exprs(&self) -> usize {
148-
self.exprs.len()
149-
}
150-
151-
/// Returns a reference to the expression for the given reference.
152-
/// Panics if the reference is invalid (use indices in range 0..num_exprs()).
153-
pub fn get_expr(&self, r: ExprRef) -> &Expr {
154-
&self[r]
155-
}
156-
157-
/// Returns the zero-based intern index of the given expression reference.
158-
pub fn expr_index(&self, r: ExprRef) -> usize {
159-
r.index()
150+
/// Returns whether `e` represents a bit vector literal `0` of any width.
151+
pub fn is_zero(&self, e: ExprRef) -> bool {
152+
if let Expr::BVLiteral(value) = self[e] {
153+
value.get(self).is_zero()
154+
} else {
155+
false
156+
}
160157
}
161158
}
162159

@@ -247,6 +244,11 @@ impl Context {
247244
pub fn ones(&mut self, width: WidthInt) -> ExprRef {
248245
self.bv_lit(&BitVecValue::ones(width))
249246
}
247+
248+
pub fn distinct(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
249+
let is_eq = self.equal(a, b);
250+
self.not(is_eq)
251+
}
250252
pub fn equal(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
251253
debug_assert_eq!(a.get_type(self), b.get_type(self));
252254
if a.get_type(self).is_bit_vector() {
@@ -311,6 +313,20 @@ impl Context {
311313
debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
312314
self.add_expr(Expr::BVXor(a, b, b.get_bv_type(self).unwrap()))
313315
}
316+
317+
pub fn xor3(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
318+
let x = self.xor(a, b);
319+
self.xor(x, c)
320+
}
321+
322+
pub fn majority(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
323+
let a_and_b = self.and(a, b);
324+
let a_and_c = self.and(a, c);
325+
let b_and_c = self.and(b, c);
326+
let x = self.or(a_and_b, a_and_c);
327+
self.or(x, b_and_c)
328+
}
329+
314330
pub fn shift_left(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
315331
debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
316332
self.add_expr(Expr::BVShiftLeft(a, b, b.get_bv_type(self).unwrap()))
@@ -515,6 +531,12 @@ impl<'a> Builder<'a> {
515531
pub fn xor(&self, a: ExprRef, b: ExprRef) -> ExprRef {
516532
self.ctx.borrow_mut().xor(a, b)
517533
}
534+
pub fn xor3(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
535+
self.ctx.borrow_mut().xor3(a, b, c)
536+
}
537+
pub fn majority(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
538+
self.ctx.borrow_mut().majority(a, b, c)
539+
}
518540
pub fn shift_left(&self, a: ExprRef, b: ExprRef) -> ExprRef {
519541
self.ctx.borrow_mut().shift_left(a, b)
520542
}

patronus/src/expr/meta.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,19 @@ impl<T: Default + Clone + Debug> Index<ExprRef> for DenseExprMetaData<T> {
109109

110110
#[inline]
111111
fn index(&self, e: ExprRef) -> &Self::Output {
112-
self.inner.get(e.index()).unwrap_or(&self.default)
112+
let index: usize = e.into();
113+
self.inner.get(index).unwrap_or(&self.default)
113114
}
114115
}
115116

116117
impl<T: Default + Clone + Debug> IndexMut<ExprRef> for DenseExprMetaData<T> {
117118
#[inline]
118119
fn index_mut(&mut self, e: ExprRef) -> &mut Self::Output {
119-
if self.inner.len() <= e.index() {
120-
self.inner.resize(e.index() + 1, T::default());
120+
let index: usize = e.into();
121+
if self.inner.len() <= index {
122+
self.inner.resize(index + 1, T::default());
121123
}
122-
&mut self.inner[e.index()]
124+
&mut self.inner[index]
123125
}
124126
}
125127

@@ -155,7 +157,7 @@ impl<'a, T> Iterator for ExprMetaDataIter<'a, T> {
155157
match self.inner.next() {
156158
None => None,
157159
Some(value) => {
158-
let index_ref = ExprRef::from_index(self.index);
160+
let index_ref = self.index.into();
159161
self.index += 1;
160162
Some((index_ref, value))
161163
}
@@ -206,7 +208,7 @@ impl ExprSet for DenseExprSet {
206208

207209
#[inline]
208210
fn index_to_word_and_bit(index: ExprRef) -> (usize, u32) {
209-
let index = index.index();
211+
let index: usize = index.into();
210212
let word = index / Word::BITS as usize;
211213
let bit = index % Word::BITS as usize;
212214
(word, bit as u32)
@@ -239,9 +241,9 @@ mod tests {
239241
#[test]
240242
fn test_get_fixed_point() {
241243
let mut m = DenseExprMetaData::default();
242-
let zero = ExprRef::from_index(0);
243-
let one = ExprRef::from_index(1);
244-
let two = ExprRef::from_index(2);
244+
let zero: ExprRef = 0usize.into();
245+
let one: ExprRef = 1usize.into();
246+
let two: ExprRef = 2usize.into();
245247
m[zero] = Some(one);
246248
m[one] = Some(two);
247249
m[two] = Some(two);
@@ -257,8 +259,9 @@ mod tests {
257259
#[test]
258260
fn test_dense_bool() {
259261
let mut m = DenseExprSet::default();
260-
assert!(!m.contains(&ExprRef::from_index(7)));
261-
m.insert(ExprRef::from_index(7));
262-
assert!(m.contains(&ExprRef::from_index(7)));
262+
let expr_ref_7: ExprRef = 7usize.into();
263+
assert!(!m.contains(&expr_ref_7));
264+
m.insert(expr_ref_7);
265+
assert!(m.contains(&expr_ref_7));
263266
}
264267
}

patronus/src/expr/simplify.rs

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
// Copyright 2023 The Regents of the University of California
2-
// Copyright 2024 Cornell University
2+
// Copyright 2024-2026 Cornell University
33
// released under BSD 3-Clause License
44
// author: Kevin Laeufer <laeufer@cornell.edu>
55

66
use super::{
7-
BVLitValue, Context, Expr, ExprMap, ExprRef, SparseExprMap, TypeCheck, WidthInt,
8-
do_transform_expr,
7+
BVLitValue, Context, Expr, ExprMap, ExprRef, SerializableIrNode, SparseExprMap, TypeCheck,
8+
WidthInt, do_transform_expr, find_symbols,
99
};
1010
use crate::expr::meta::get_fixed_point;
1111
use crate::expr::transform::ExprTransformMode;
12+
use crate::smt::{CheckSatResponse, SolverContext};
1213
use baa::BitVecOps;
1314
use smallvec::{SmallVec, smallvec};
1415

@@ -38,6 +39,68 @@ impl<T: ExprMap<Option<ExprRef>>> Simplifier<T> {
3839
);
3940
get_fixed_point(&mut self.cache, e).unwrap()
4041
}
42+
43+
/// Uses an SMT solver to check all simplification steps that have been made with this Simplifier.
44+
/// Returns the number of incorrect simplifications.
45+
pub fn verify_simplification(
46+
&self,
47+
ctx: &mut Context,
48+
solver: &mut impl SolverContext,
49+
) -> crate::smt::Result<usize> {
50+
let mut incorrect = 0;
51+
let mut correct = 0;
52+
for (key, &value) in self.cache.iter() {
53+
if let Some(simplified) = value {
54+
let key_symbols = find_symbols(ctx, key);
55+
let simpl_symbols = find_symbols(ctx, simplified);
56+
let symbols: Vec<_> = key_symbols.union(&simpl_symbols).cloned().collect();
57+
solver.push()?;
58+
for &sym in symbols.iter() {
59+
solver.declare_const(ctx, sym)?;
60+
}
61+
let not_eq = ctx.distinct(key, simplified);
62+
solver.assert(ctx, not_eq)?;
63+
match solver.check_sat()? {
64+
CheckSatResponse::Sat => {
65+
let key_value = solver.get_value(ctx, key)?;
66+
let simplified_value = solver.get_value(ctx, simplified)?;
67+
println!(
68+
"{} ({}) =/= ({}) {}",
69+
key.serialize_to_str(ctx),
70+
key_value.serialize_to_str(ctx),
71+
simplified_value.serialize_to_str(ctx),
72+
simplified.serialize_to_str(ctx)
73+
);
74+
let mut syms = vec![];
75+
for &sym in symbols.iter() {
76+
let value = solver.get_value(ctx, sym)?;
77+
syms.push(format!(
78+
"{}={}",
79+
sym.serialize_to_str(ctx),
80+
value.serialize_to_str(ctx)
81+
));
82+
}
83+
println!(" w/ {}", syms.join(", "));
84+
incorrect += 1;
85+
}
86+
CheckSatResponse::Unsat => {
87+
correct += 1;
88+
} // OK
89+
CheckSatResponse::Unknown => {} // OK
90+
}
91+
92+
solver.pop()?;
93+
}
94+
}
95+
if incorrect > 0 {
96+
println!(
97+
"{incorrect} / {} simplifications were incorrect. See log.",
98+
incorrect + correct
99+
);
100+
}
101+
102+
Ok(incorrect)
103+
}
41104
}
42105

43106
/// Simplifies one expression (not its children)
@@ -256,8 +319,11 @@ fn simplify_bv_and(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
256319
// a & !a -> 0
257320
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.zero(*w)),
258321
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.zero(*w)),
259-
// !a & !b -> a | b
260-
(Expr::BVNot(a, _), Expr::BVNot(b, _)) => Some(ctx.or(*a, *b)),
322+
// !a & !b -> !(a | b)
323+
(Expr::BVNot(a, _), Expr::BVNot(b, _)) => {
324+
let or = ctx.or(*a, *b);
325+
Some(ctx.not(or))
326+
}
261327
_ => None,
262328
}
263329
}
@@ -292,8 +358,11 @@ fn simplify_bv_or(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
292358
// a | !a -> 1
293359
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.ones(*w)),
294360
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.ones(*w)),
295-
// !a | !b -> a & b
296-
(Expr::BVNot(a, _), Expr::BVNot(b, _)) => Some(ctx.and(*a, *b)),
361+
// !a | !b -> !(a & b)
362+
(Expr::BVNot(a, _), Expr::BVNot(b, _)) => {
363+
let and = ctx.and(*a, *b);
364+
Some(ctx.not(and))
365+
}
297366
_ => None,
298367
}
299368
}

0 commit comments

Comments
 (0)