Skip to content

Commit 515021d

Browse files
authored
Fix range analysis (#376)
1 parent 2a86b6e commit 515021d

File tree

2 files changed

+188
-13
lines changed

2 files changed

+188
-13
lines changed

src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ class AffineIntegerRangeAnalysis
194194
/// affine.parallel (%i) = (0) to (10) {
195195
/// // getBoundsFromAffineParallel(op, 0) returns {0, 9}
196196
/// }
197-
std::pair<APInt, APInt>
198-
getBoundsFromAffineParallel(affine::AffineParallelOp loop, size_t idx) {
197+
ConstantIntRanges getBoundsFromAffineParallel(affine::AffineParallelOp loop,
198+
size_t idx) {
199199
SmallVector<AffineExpr> lbounds(
200200
loop.getLowerBoundsMap().getResults().begin(),
201201
loop.getLowerBoundsMap().getResults().end());
@@ -225,11 +225,12 @@ class AffineIntegerRangeAnalysis
225225

226226
if (lb && ub) {
227227
// Create APInt values with 64 bit.
228-
return {APInt(/*numBits=*/64, lb.getValue(), /*isSigned=*/true),
229-
APInt(/*numBits=*/64, ub.getValue() - 1, /*isSigned=*/true)};
228+
return ConstantIntRanges::fromSigned(
229+
APInt(/*numBits=*/64, lb.getValue(), /*isSigned=*/true),
230+
APInt(/*numBits=*/64, ub.getValue() - 1, /*isSigned=*/true));
230231
}
231232
// Return sentinel values if bounds cannot be determined
232-
return {APInt::getSignedMinValue(64), APInt::getSignedMaxValue(64)};
233+
return ConstantIntRanges::maxRange(64);
233234
}
234235
};
235236

@@ -279,9 +280,9 @@ void AffineIntegerRangeAnalysis::visitNonControlFlowArguments(
279280
// not expose all the necessary interfaces/methods.
280281
if (auto loop = dyn_cast<affine::AffineParallelOp>(op)) {
281282
for (Value iv : loop.getIVs()) {
282-
auto [min, max] = getBoundsFromAffineParallel(loop, 0);
283+
ConstantIntRanges ivRange = getBoundsFromAffineParallel(
284+
loop, cast<BlockArgument>(iv).getArgNumber());
283285
IntegerValueRangeLattice *ivEntry = getLatticeElement(iv);
284-
auto ivRange = ConstantIntRanges::fromSigned(min, max);
285286
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
286287
}
287288
return;
@@ -614,26 +615,184 @@ struct CanonicalizeLoopsPass
614615
range.getValue().getConstantValue();
615616
if (!constantRangeValue.has_value())
616617
return;
617-
if (constantRangeValue->eq(cstRhs)) {
618+
b.setInsertionPoint(cmpiOp);
619+
auto cst = b.create<arith::ConstantOp>(
620+
cmpiOp.getLoc(), b.getI1Type(),
621+
IntegerAttr::get(b.getI1Type(), !constantRangeValue->eq(cstRhs)));
622+
cmpiOp.getResult().replaceAllUsesWith(cst);
623+
}
624+
if (pred == arith::CmpIPredicate::eq) {
625+
std::optional<APInt> constantRangeValue =
626+
range.getValue().getConstantValue();
627+
if (!constantRangeValue.has_value())
628+
return;
629+
b.setInsertionPoint(cmpiOp);
630+
auto cst = b.create<arith::ConstantOp>(
631+
cmpiOp.getLoc(), b.getI1Type(),
632+
IntegerAttr::get(b.getI1Type(), constantRangeValue->eq(cstRhs)));
633+
cmpiOp.getResult().replaceAllUsesWith(cst);
634+
}
635+
if (pred == arith::CmpIPredicate::ult) {
636+
const APInt umax = cstRange.umax();
637+
const APInt umin = cstRange.umin();
638+
if (umax.ult(cstRhs)) {
639+
// Condition always true.
640+
b.setInsertionPoint(cmpiOp);
641+
auto cst = b.create<arith::ConstantOp>(
642+
cmpiOp.getLoc(), b.getI1Type(),
643+
IntegerAttr::get(b.getI1Type(), true));
644+
cmpiOp.getResult().replaceAllUsesWith(cst);
645+
}
646+
// range < cst -> !(range >= cst)
647+
if (umin.uge(cstRhs)) {
648+
// Condition always false.
618649
b.setInsertionPoint(cmpiOp);
619650
auto cst = b.create<arith::ConstantOp>(
620651
cmpiOp.getLoc(), b.getI1Type(),
621652
IntegerAttr::get(b.getI1Type(), false));
622653
cmpiOp.getResult().replaceAllUsesWith(cst);
623654
}
624655
}
625-
if (pred == arith::CmpIPredicate::ult) {
656+
if (pred == arith::CmpIPredicate::ule) {
657+
const APInt umax = cstRange.umax();
658+
const APInt umin = cstRange.umin();
659+
if (umax.ule(cstRhs)) {
660+
// Condition always true.
661+
b.setInsertionPoint(cmpiOp);
662+
auto cst = b.create<arith::ConstantOp>(
663+
cmpiOp.getLoc(), b.getI1Type(),
664+
IntegerAttr::get(b.getI1Type(), true));
665+
cmpiOp.getResult().replaceAllUsesWith(cst);
666+
}
667+
// range <= cst -> !(range > cst)
668+
if (umin.ugt(cstRhs)) {
669+
// Condition always false.
670+
b.setInsertionPoint(cmpiOp);
671+
auto cst = b.create<arith::ConstantOp>(
672+
cmpiOp.getLoc(), b.getI1Type(),
673+
IntegerAttr::get(b.getI1Type(), false));
674+
cmpiOp.getResult().replaceAllUsesWith(cst);
675+
}
676+
}
677+
if (pred == arith::CmpIPredicate::ugt) {
626678
const APInt umax = cstRange.umax();
627679
const APInt umin = cstRange.umin();
628-
if (umax.ult(cstRhs) && umin.ult(cstRhs)) {
680+
if (umax.ugt(cstRhs)) {
681+
// Condition always true.
682+
b.setInsertionPoint(cmpiOp);
683+
auto cst = b.create<arith::ConstantOp>(
684+
cmpiOp.getLoc(), b.getI1Type(),
685+
IntegerAttr::get(b.getI1Type(), true));
686+
cmpiOp.getResult().replaceAllUsesWith(cst);
687+
}
688+
// range > cst -> !(range <= cst)
689+
if (umin.ule(cstRhs)) {
690+
// Condition always false.
691+
b.setInsertionPoint(cmpiOp);
692+
auto cst = b.create<arith::ConstantOp>(
693+
cmpiOp.getLoc(), b.getI1Type(),
694+
IntegerAttr::get(b.getI1Type(), false));
695+
cmpiOp.getResult().replaceAllUsesWith(cst);
696+
}
697+
}
698+
if (pred == arith::CmpIPredicate::uge) {
699+
const APInt umax = cstRange.umax();
700+
const APInt umin = cstRange.umin();
701+
if (umax.uge(cstRhs)) {
702+
// Condition always true.
703+
b.setInsertionPoint(cmpiOp);
704+
auto cst = b.create<arith::ConstantOp>(
705+
cmpiOp.getLoc(), b.getI1Type(),
706+
IntegerAttr::get(b.getI1Type(), true));
707+
cmpiOp.getResult().replaceAllUsesWith(cst);
708+
}
709+
// range >= cst -> !(range < cst)
710+
if (umin.ult(cstRhs)) {
711+
// Condition always false.
712+
b.setInsertionPoint(cmpiOp);
713+
auto cst = b.create<arith::ConstantOp>(
714+
cmpiOp.getLoc(), b.getI1Type(),
715+
IntegerAttr::get(b.getI1Type(), false));
716+
cmpiOp.getResult().replaceAllUsesWith(cst);
717+
}
718+
}
719+
720+
if (pred == arith::CmpIPredicate::slt) {
721+
const APInt smax = cstRange.smax();
722+
const APInt smin = cstRange.smin();
723+
if (smax.slt(cstRhs)) {
724+
// Condition always true.
725+
b.setInsertionPoint(cmpiOp);
726+
auto cst = b.create<arith::ConstantOp>(
727+
cmpiOp.getLoc(), b.getI1Type(),
728+
IntegerAttr::get(b.getI1Type(), true));
729+
cmpiOp.getResult().replaceAllUsesWith(cst);
730+
}
731+
// range < cst -> !(range >= cst)
732+
if (smin.sge(cstRhs)) {
733+
// Condition always false.
734+
b.setInsertionPoint(cmpiOp);
735+
auto cst = b.create<arith::ConstantOp>(
736+
cmpiOp.getLoc(), b.getI1Type(),
737+
IntegerAttr::get(b.getI1Type(), false));
738+
cmpiOp.getResult().replaceAllUsesWith(cst);
739+
}
740+
}
741+
if (pred == arith::CmpIPredicate::sle) {
742+
const APInt smax = cstRange.smax();
743+
const APInt smin = cstRange.smin();
744+
if (smax.sle(cstRhs)) {
745+
// Condition always true.
746+
b.setInsertionPoint(cmpiOp);
747+
auto cst = b.create<arith::ConstantOp>(
748+
cmpiOp.getLoc(), b.getI1Type(),
749+
IntegerAttr::get(b.getI1Type(), true));
750+
cmpiOp.getResult().replaceAllUsesWith(cst);
751+
}
752+
// range <= cst -> !(range > cst)
753+
if (smin.sgt(cstRhs)) {
754+
// Condition always false.
755+
b.setInsertionPoint(cmpiOp);
756+
auto cst = b.create<arith::ConstantOp>(
757+
cmpiOp.getLoc(), b.getI1Type(),
758+
IntegerAttr::get(b.getI1Type(), false));
759+
cmpiOp.getResult().replaceAllUsesWith(cst);
760+
}
761+
}
762+
if (pred == arith::CmpIPredicate::sgt) {
763+
const APInt smax = cstRange.smax();
764+
const APInt smin = cstRange.smin();
765+
if (smax.sgt(cstRhs)) {
766+
// Condition always true.
767+
b.setInsertionPoint(cmpiOp);
768+
auto cst = b.create<arith::ConstantOp>(
769+
cmpiOp.getLoc(), b.getI1Type(),
770+
IntegerAttr::get(b.getI1Type(), true));
771+
cmpiOp.getResult().replaceAllUsesWith(cst);
772+
}
773+
// range > cst -> !(range <= cst)
774+
if (smin.sle(cstRhs)) {
775+
// Condition always false.
776+
b.setInsertionPoint(cmpiOp);
777+
auto cst = b.create<arith::ConstantOp>(
778+
cmpiOp.getLoc(), b.getI1Type(),
779+
IntegerAttr::get(b.getI1Type(), false));
780+
cmpiOp.getResult().replaceAllUsesWith(cst);
781+
}
782+
}
783+
if (pred == arith::CmpIPredicate::sge) {
784+
const APInt smax = cstRange.smax();
785+
const APInt smin = cstRange.smin();
786+
if (smax.sge(cstRhs)) {
629787
// Condition always true.
630788
b.setInsertionPoint(cmpiOp);
631789
auto cst = b.create<arith::ConstantOp>(
632790
cmpiOp.getLoc(), b.getI1Type(),
633791
IntegerAttr::get(b.getI1Type(), true));
634792
cmpiOp.getResult().replaceAllUsesWith(cst);
635793
}
636-
if (!umax.ult(cstRhs) && !umin.ult(cstRhs)) {
794+
// range >= cst -> !(range < cst)
795+
if (smin.slt(cstRhs)) {
637796
// Condition always false.
638797
b.setInsertionPoint(cmpiOp);
639798
auto cst = b.create<arith::ConstantOp>(

test/lit_tests/integer-analysis.mlir

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(func.func(canonicalize-loops), canonicalize)" --split-input-file %s | FileCheck %s
1+
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(func.func(canonicalize-loops), canonicalize)" --split-input-file %s -allow-unregistered-dialect | FileCheck %s
22

33
// CHECK-LABEL: foo
44
// CHECK-SAME: %[[ARG0:.+]]: memref<256xi64>
@@ -170,4 +170,20 @@ func.func private @foo(%arg0: memref<1x134x374xf64, 1>) {
170170
// CHECK-NEXT: %[[V13:.+]] = affine.load %[[ARG0]][0, %[[ARG1]], 13] : memref<1x134x374xf64, 1>
171171
// CHECK-NEXT: affine.store %[[V13]], %[[ARG0]][0, %[[ARG1]], 373] : memref<1x134x374xf64, 1>
172172
// CHECK-NEXT: }
173-
// CHECK-NEXT: return
173+
// CHECK-NEXT: return
174+
175+
176+
func.func @argmatch() {
177+
%c25 = arith.constant 25 : index
178+
affine.parallel (%arg0, %arg1) = (0, 0) to (24, 256) {
179+
%0 = arith.cmpi ult, %arg0, %c25 : index
180+
%1 = arith.cmpi ult, %arg1, %c25 : index
181+
"test.operation"(%0, %1) : (i1, i1) -> ()
182+
}
183+
return
184+
}
185+
186+
// CHECK-LABEL: @argmatch
187+
// CHECK: affine.parallel (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) to (24, 256)
188+
// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[ARG1]], %c25
189+
// CHECK: "test.operation"(%true, %[[CMP]])

0 commit comments

Comments
 (0)