Skip to content

Commit 2cbd4fe

Browse files
Fix IntervalAnalysis arith.constant issue with i1 values (#284)
* Fix SMT expression error with arith.constant i1s * Add changelog * Revert unneeded formatting change * Revert "Revert unneeded formatting change" This reverts commit c87c7c7. * Fix annoying diff between clang-format versions * Update lib/Analysis/IntervalAnalysis.cpp Co-authored-by: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> --------- Co-authored-by: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com>
1 parent 358b000 commit 2cbd4fe

File tree

6 files changed

+75
-22
lines changed

6 files changed

+75
-22
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
fixed:
2+
- Fixed issue with interval analysis in creating SMT expressions for arith.constant operations

include/llzk/Analysis/AnalysisWrappers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ class ModuleAnalysis {
221221
ensure(res.succeeded(), "solver failed to run on module!");
222222

223223
const Context &ctx = getContext();
224+
// Force construction of empty results here so `getCurrentResults()` on
225+
// a module with no inner structs returns no results rather than an assertion
226+
// failure.
227+
results[ctx] = {};
224228
modOp.walk([this, &am, &ctx](component::StructDefOp s) mutable {
225229
auto &childAnalysis = am.getChildAnalysis<StructAnalysisTy>(s);
226230
// Don't re-run the analysis if we already have the results.

include/llzk/Analysis/IntervalAnalysis.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,20 @@ class IntervalDataFlowAnalysis
307307

308308
llvm::SMTExprRef createFeltSymbol(const char *name) const;
309309

310-
bool isConstOp(mlir::Operation *op) const {
310+
inline bool isConstOp(mlir::Operation *op) const {
311311
return llvm::isa<
312312
felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp>(op);
313313
}
314314

315+
inline bool isBoolConstOp(mlir::Operation *op) const {
316+
if (auto constIntOp = llvm::dyn_cast<mlir::arith::ConstantIntOp>(op)) {
317+
auto valAttr = dyn_cast<mlir::IntegerAttr>(constIntOp.getValue());
318+
ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
319+
return valAttr.getValue().getBitWidth() == 1;
320+
}
321+
return false;
322+
}
323+
315324
llvm::DynamicAPInt getConst(mlir::Operation *op) const;
316325

317326
inline llvm::SMTExprRef createConstBitvectorExpr(const llvm::DynamicAPInt &v) const {
@@ -322,9 +331,7 @@ class IntervalDataFlowAnalysis
322331
return smtSolver->mkBitvector(v, field.get().bitWidth());
323332
}
324333

325-
llvm::SMTExprRef createConstBoolExpr(bool v) const {
326-
return smtSolver->mkBitvector(mlir::APSInt((int)v), field.get().bitWidth());
327-
}
334+
llvm::SMTExprRef createConstBoolExpr(bool v) const { return smtSolver->mkBoolean(v); }
328335

329336
bool isArithmeticOp(mlir::Operation *op) const {
330337
return llvm::isa<

lib/Analysis/IntervalAnalysis.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,13 @@ mlir::LogicalResult IntervalDataFlowAnalysis::visitOperation(
463463
// Now, the way we update is dependent on the type of the operation.
464464
if (isConstOp(op)) {
465465
llvm::DynamicAPInt constVal = getConst(op);
466-
llvm::SMTExprRef expr = createConstBitvectorExpr(constVal);
466+
llvm::SMTExprRef expr;
467+
if (isBoolConstOp(op)) {
468+
expr = createConstBoolExpr(constVal != 0);
469+
} else {
470+
expr = createConstBitvectorExpr(constVal);
471+
}
472+
467473
ExpressionValue latticeVal(field.get(), expr, constVal);
468474
propagateIfChanged(results[0], results[0]->setValue(latticeVal));
469475
} else if (isArithmeticOp(op)) {
@@ -619,23 +625,28 @@ llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) co
619625
llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
620626
ensure(isConstOp(op), "op is not a const op");
621627

622-
llvm::DynamicAPInt fieldConst =
623-
TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
624-
.Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
625-
llvm::APSInt constOpVal(feltConst.getValue());
626-
return field.get().reduce(constOpVal);
627-
})
628-
.Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
629-
return DynamicAPInt(indexConst.value());
630-
})
631-
.Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
632-
return DynamicAPInt(intConst.value());
633-
}).Default([](Operation *illegalOp) {
634-
std::string err;
635-
debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
636-
llvm::report_fatal_error(Twine(err));
637-
return llvm::DynamicAPInt();
638-
});
628+
// NOTE: I think clang-format makes these hard to read by default
629+
// clang-format off
630+
llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
631+
.Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
632+
llvm::APSInt constOpVal(feltConst.getValue());
633+
return field.get().reduce(constOpVal);
634+
})
635+
.Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
636+
return DynamicAPInt(indexConst.value());
637+
})
638+
.Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
639+
auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
640+
ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
641+
return toDynamicAPInt(valAttr.getValue());
642+
})
643+
.Default([](Operation *illegalOp) {
644+
std::string err;
645+
debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
646+
llvm::report_fatal_error(Twine(err));
647+
return llvm::DynamicAPInt();
648+
});
649+
// clang-format on
639650
return fieldConst;
640651
}
641652

test/Analysis/interval_analysis/interval_analysis_pass.llzk

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,3 +670,8 @@ module attributes {veridise.lang = "llzk"} {
670670
// CHECK-NEXT: %arg0[@y] in TypeA:[ 0, 255 ]
671671
// CHECK-NEXT: %arg0[@z] in TypeF:[ 21888242871839275222246405745257275088696311157297823662689037894645226208328, 255 ]
672672
// CHECK-NEXT: }
673+
674+
// -----
675+
676+
// COM: Regression test to avoid assertion failure on empty modules
677+
module attributes {veridise.lang = "llzk"} { }

test/Analysis/interval_analysis/interval_analysis_pass_compute.llzk

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,30 @@ module attributes {veridise.lang = "llzk"} {
604604

605605
// -----
606606

607+
module attributes {veridise.lang = "llzk"} {
608+
struct.def @ConstBools {
609+
struct.field @out : i1
610+
function.def @compute() -> !struct.type<@ConstBools> attributes {function.allow_witness} {
611+
%self = struct.new : !struct.type<@ConstBools>
612+
%true = arith.constant true
613+
%false = arith.constant false
614+
%res = bool.and %true, %false
615+
struct.writef %self[@out] = %res : !struct.type<@ConstBools>, i1
616+
function.return %self : !struct.type<@ConstBools>
617+
}
618+
function.def @constrain(%arg0: !struct.type<@ConstBools>) attributes {function.allow_constraint} {
619+
function.return
620+
}
621+
}
622+
}
623+
624+
// CHECK-LABEL: @ConstBools StructIntervals {
625+
// CHECK-NEXT: compute {
626+
// CHECK-NEXT: %self[@out] in Degenerate(0)
627+
// CHECK-NEXT: }
628+
629+
// -----
630+
607631
module attributes {veridise.lang = "llzk"} {
608632
struct.def @ArithConstraint {
609633

0 commit comments

Comments
 (0)