From 4755da6907951c7034310910227a817d3d607fec Mon Sep 17 00:00:00 2001 From: donald chen Date: Tue, 11 Feb 2025 10:07:26 +0000 Subject: [PATCH] [mlir] [DataFlow] Fix bug in int-range-analysis When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation. --- .../DataFlow/IntegerRangeAnalysis.cpp | 11 +++---- .../infer-int-range-test-ops.mlir | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 9e9411e5ede12..722f4df18e981 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -152,7 +152,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( /// on a LoopLikeInterface return the lower/upper bound for that result if /// possible. auto getLoopBoundFromFold = [&](std::optional loopBound, - Type boundType, bool getUpper) { + Type boundType, Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); if (loopBound.has_value()) { if (auto attr = dyn_cast(*loopBound)) { @@ -160,7 +160,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return bound.getValue(); } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { const IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointAfter(op), value); + getLatticeElementFor(getProgramPointBefore(block), value); if (lattice != nullptr && !lattice->getValue().isUninitialized()) return getUpper ? lattice->getValue().getValue().smax() : lattice->getValue().getValue().smin(); @@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } + Block *block = iv->getParentBlock(); std::optional lowerBound = loop.getSingleLowerBound(); std::optional upperBound = loop.getSingleUpperBound(); std::optional step = loop.getSingleStep(); - APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), + APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, /*getUpper=*/false); - APInt max = getLoopBoundFromFold(upperBound, iv->getType(), + APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, /*getUpper=*/true); // Assume positivity for uniscoverable steps by way of getUpper = true. APInt stepVal = - getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); + getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); if (stepVal.isNegative()) { std::swap(min, max); diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir index 1ec3441b1fde8..b98e8b07db5ce 100644 --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index { return %2 : index } +// CHECK-LABEL: @propagate_from_block_to_iterarg +func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = scf.if %arg1 -> (index) { + %1 = scf.if %arg1 -> (index) { + scf.yield %arg0 : index + } else { + scf.yield %arg0 : index + } + scf.yield %1 : index + } else { + scf.yield %c1 : index + } + scf.for %arg2 = %c0 to %arg0 step %c1 { + scf.if %arg1 { + %1 = arith.subi %0, %c1 : index + %2 = arith.muli %0, %1 : index + %3 = arith.addi %2, %c1 : index + scf.for %arg3 = %c0 to %3 step %c1 { + %4 = arith.cmpi uge, %arg3, %c1 : index + // CHECK-NOT: scf.if %false + scf.if %4 { + "test.foo"() : () -> () + } + } + } + } + return +}