Skip to content

Commit a83380a

Browse files
committed
[mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly
The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet. This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow. Fixes #119045
1 parent 2dba66b commit a83380a

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
3131
public:
3232
using Lattice::Lattice;
3333

34+
/// Override the join logic so that arguments to non-entry blocks
35+
/// whose arguments come from later in the program get set to
36+
/// a maximal value so that we don't prematurely declare code to be
37+
/// deade.
38+
ChangeResult join(const AbstractSparseLattice &rhs) override;
39+
40+
ChangeResult join(const IntegerValueRange &range) {
41+
return Lattice::join(range);
42+
}
43+
3444
/// If the range can be narrowed to an integer constant, update the constant
3545
/// value of the SSA value.
3646
void onUpdate(DataFlowSolver *solver) const override;

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,50 @@
3737
using namespace mlir;
3838
using namespace mlir::dataflow;
3939

40+
/// Return true if `block` is a non-entry block with a predecessor that's
41+
/// defined after the block. This allows us to detect loop-varying values
42+
/// in unstructured control flow.
43+
static bool isLoopLikeBlock(Block *block) {
44+
if (!block || block->isEntryBlock())
45+
return false;
46+
Region *parent = block->getParent();
47+
if (!parent)
48+
return false;
49+
50+
SmallPtrSet<Block *, 4> preds;
51+
for (Block *pred : block->getPredecessors())
52+
preds.insert(pred);
53+
if (preds.size() <= 1)
54+
return false;
55+
56+
for (Block &regionBlock : parent->getBlocks()) {
57+
if (&regionBlock == block)
58+
break;
59+
preds.erase(&regionBlock);
60+
}
61+
62+
// The block loops back on itself or has an edge from further in the program.
63+
return !preds.empty();
64+
}
65+
66+
ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
67+
Value lhsAnchor = getAnchor();
68+
Block *lhsBlock = lhsAnchor.getParentBlock();
69+
unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType());
70+
/// Special-case: we're in unstructured control flow and one of the
71+
/// predecessors of this block argument is defined in a block that comes after
72+
/// the argument. So we conservatively conclude that the value could be
73+
/// anything.
74+
if (width > 0 && isa<BlockArgument>(lhsAnchor) && isLoopLikeBlock(lhsBlock)) {
75+
LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor
76+
<< " from " << rhs.getAnchor() << "\n");
77+
LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n");
78+
IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor);
79+
return join(maxRange);
80+
}
81+
return Lattice::join(rhs);
82+
}
83+
4084
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
4185
Lattice::onUpdate(solver);
4286

@@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
206250
if (max.sge(min)) {
207251
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
208252
auto ivRange = ConstantIntRanges::fromSigned(min, max);
253+
LLVM_DEBUG(llvm::dbgs()
254+
<< "Inferred loop bound range: " << ivRange << "\n");
209255
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
210256
}
211257
return;

mlir/test/Dialect/Arith/int-range-opts.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,63 @@ func.func @wraps() -> i8 {
132132
%mod = arith.remsi %val, %c64 : i8
133133
return %mod : i8
134134
}
135+
136+
// -----
137+
138+
// Note: I wish I had a simpler example than this, but getting rid of a
139+
// bunch of the arithmetic made the issue go away.
140+
// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
141+
// CHECK-NOT: arith.constant true
142+
func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
143+
%cst = arith.constant dense<false> : vector<1xi1>
144+
%c1 = arith.constant 1 : index
145+
%cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
146+
%cst_1 = arith.constant 0.000000e+00 : f16
147+
%c16 = arith.constant 16 : index
148+
%c0 = arith.constant 0 : index
149+
%c64 = arith.constant 64 : index
150+
%thread_id_x = gpu.thread_id x upper_bound 64
151+
%6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
152+
%8 = arith.divui %6, %c16 : index
153+
%9 = arith.muli %8, %c16 : index
154+
cf.br ^bb1(%c0 : index)
155+
^bb1(%12: index): // 2 preds: ^bb0, ^bb7
156+
%13 = arith.cmpi slt, %12, %9 : index
157+
cf.cond_br %13, ^bb2, ^bb8
158+
^bb2: // pred: ^bb1
159+
%14 = arith.subi %9, %12 : index
160+
%15 = arith.minsi %14, %c64 : index
161+
%16 = arith.subi %15, %thread_id_x : index
162+
%17 = vector.constant_mask [1] : vector<1xi1>
163+
%18 = arith.cmpi sgt, %16, %c0 : index
164+
%19 = arith.select %18, %17, %cst : vector<1xi1>
165+
%20 = vector.extract %19[0] : i1 from vector<1xi1>
166+
%21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
167+
%22 = arith.addi %12, %thread_id_x : index
168+
cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
169+
^bb3(%23: index, %24: vector<1xf16>): // 2 preds: ^bb2, ^bb6
170+
%25 = arith.cmpi slt, %23, %c1 : index
171+
cf.cond_br %25, ^bb4, ^bb7
172+
^bb4: // pred: ^bb3
173+
%26 = vector.extractelement %21[%23 : index] : vector<1xi1>
174+
cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
175+
^bb5: // pred: ^bb4
176+
%27 = arith.addi %22, %23 : index
177+
%28 = memref.load %mem[%27] : memref<?xf16>
178+
%29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
179+
cf.br ^bb6(%29 : vector<1xf16>)
180+
^bb6(%30: vector<1xf16>): // 2 preds: ^bb4, ^bb5
181+
%31 = arith.addi %23, %c1 : index
182+
cf.br ^bb3(%31, %30 : index, vector<1xf16>)
183+
^bb7: // pred: ^bb3
184+
%37 = arith.addi %12, %c64 : index
185+
cf.br ^bb1(%37 : index)
186+
^bb8: // pred: ^bb1
187+
%70 = arith.cmpi eq, %thread_id_x, %c0 : index
188+
cf.cond_br %70, ^bb9, ^bb10
189+
^bb9: // pred: ^bb8
190+
memref.store %cst_1, %mem[%c0] : memref<?xf16>
191+
cf.br ^bb10
192+
^bb10: // 2 preds: ^bb8, ^bb9
193+
return
194+
}

0 commit comments

Comments
 (0)