Skip to content

Commit ad1edc9

Browse files
authored
[mlir][IntegerRangeAnalysis] Handle multi-dimensional loops (#170765)
Since LoopLikeInterface has (for some time) been extended to handle multiple induction variables (and thus lower and upper bounds), handle those bounds one at a time.
1 parent 5e4974f commit ad1edc9

File tree

2 files changed

+62
-44
lines changed

2 files changed

+62
-44
lines changed

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
180180
return;
181181
}
182182

183-
/// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
184-
/// on a LoopLikeInterface return the lower/upper bound for that result if
185-
/// possible.
186-
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
187-
Type boundType, Block *block, bool getUpper) {
183+
/// Given a lower bound, upper bound, or step from a LoopLikeInterface return
184+
/// the lower/upper bound for that result if possible.
185+
auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
186+
Block *block, bool getUpper) {
188187
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
189-
if (loopBound.has_value()) {
190-
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
191-
if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
192-
return bound.getValue();
193-
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
194-
const IntegerValueRangeLattice *lattice =
195-
getLatticeElementFor(getProgramPointBefore(block), value);
196-
if (lattice != nullptr && !lattice->getValue().isUninitialized())
197-
return getUpper ? lattice->getValue().getValue().smax()
198-
: lattice->getValue().getValue().smin();
199-
}
188+
if (auto attr = dyn_cast<Attribute>(loopBound)) {
189+
if (auto bound = dyn_cast<IntegerAttr>(attr))
190+
return bound.getValue();
191+
} else if (auto value = llvm::dyn_cast<Value>(loopBound)) {
192+
const IntegerValueRangeLattice *lattice =
193+
getLatticeElementFor(getProgramPointBefore(block), value);
194+
if (lattice != nullptr && !lattice->getValue().isUninitialized())
195+
return getUpper ? lattice->getValue().getValue().smax()
196+
: lattice->getValue().getValue().smin();
200197
}
201198
// Given the results of getConstant{Lower,Upper}Bound()
202199
// or getConstantStep() on a LoopLikeInterface return the lower/upper
@@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
207204

208205
// Infer bounds for loop arguments that have static bounds
209206
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
210-
std::optional<Value> iv = loop.getSingleInductionVar();
211-
if (!iv) {
207+
std::optional<llvm::SmallVector<Value>> maybeIvs =
208+
loop.getLoopInductionVars();
209+
if (!maybeIvs) {
212210
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
213211
op, successor, argLattices, firstIndex);
214212
}
215-
Block *block = iv->getParentBlock();
216-
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
217-
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
218-
std::optional<OpFoldResult> step = loop.getSingleStep();
219-
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
220-
/*getUpper=*/false);
221-
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
222-
/*getUpper=*/true);
223-
// Assume positivity for uniscoverable steps by way of getUpper = true.
224-
APInt stepVal =
225-
getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
226-
227-
if (stepVal.isNegative()) {
228-
std::swap(min, max);
229-
} else {
230-
// Correct the upper bound by subtracting 1 so that it becomes a <=
231-
// bound, because loops do not generally include their upper bound.
232-
max -= 1;
233-
}
213+
// This shouldn't be returning nullopt if there are indunction variables.
214+
SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds();
215+
SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds();
216+
SmallVector<OpFoldResult> steps = *loop.getLoopSteps();
217+
for (auto [iv, lowerBound, upperBound, step] :
218+
llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) {
219+
Block *block = iv.getParentBlock();
220+
APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block,
221+
/*getUpper=*/false);
222+
APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block,
223+
/*getUpper=*/true);
224+
// Assume positivity for uniscoverable steps by way of getUpper = true.
225+
APInt stepVal =
226+
getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true);
227+
228+
if (stepVal.isNegative()) {
229+
std::swap(min, max);
230+
} else {
231+
// Correct the upper bound by subtracting 1 so that it becomes a <=
232+
// bound, because loops do not generally include their upper bound.
233+
max -= 1;
234+
}
234235

235-
// If we infer the lower bound to be larger than the upper bound, the
236-
// resulting range is meaningless and should not be used in further
237-
// inferences.
238-
if (max.sge(min)) {
239-
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
240-
auto ivRange = ConstantIntRanges::fromSigned(min, max);
241-
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
236+
// If we infer the lower bound to be larger than the upper bound, the
237+
// resulting range is meaningless and should not be used in further
238+
// inferences.
239+
if (max.sge(min)) {
240+
IntegerValueRangeLattice *ivEntry = getLatticeElement(iv);
241+
auto ivRange = ConstantIntRanges::fromSigned(min, max);
242+
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
243+
}
242244
}
243245
return;
244246
}

mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
184184
}
185185
return
186186
}
187+
188+
// CHECK-LABEL: func @multiple_loop_ivs
189+
func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) {
190+
%ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index,
191+
smin = 1 : index, smax = 32 : index } : index
192+
%c0_i32 = arith.constant 0 : i32
193+
// CHECK: scf.forall
194+
scf.forall (%arg1, %arg2) in (%ub1, 64) {
195+
// CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index}
196+
%1 = test.reflect_bounds %arg1 : index
197+
// CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index}
198+
%2 = test.reflect_bounds %arg2 : index
199+
memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32>
200+
}
201+
return
202+
}

0 commit comments

Comments
 (0)