Skip to content

Commit 0b30c4e

Browse files
use ValueBoundsOpInterface.
1 parent e865351 commit 0b30c4e

File tree

3 files changed

+54
-100
lines changed

3 files changed

+54
-100
lines changed

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 39 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212

1313
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1414

15-
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
16-
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1715
#include "mlir/Analysis/SliceAnalysis.h"
1816
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1917
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
2018
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
2119
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2220
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
2321
#include "mlir/Interfaces/FunctionInterfaces.h"
22+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2423
#include "llvm/Support/MathExtras.h"
2524

2625
#include "llvm/ADT/DenseSet.h"
@@ -33,7 +32,6 @@
3332

3433
using namespace mlir;
3534
using namespace mlir::affine;
36-
using namespace mlir::dataflow;
3735

3836
#define DEBUG_TYPE "affine-loop-analysis"
3937

@@ -88,69 +86,37 @@ void mlir::affine::getTripCountMapAndOperands(
8886
tripCountValueMap.getOperands().end());
8987
}
9088

91-
/// By running `IntegerRangeAnalysis` to get the ranges of operand, then fill
92-
/// the `symReplacements` with range. If `replaceByMin` is set to true,
93-
/// construct `replacement` using the smallest value.By default, the largest
94-
/// value will be used for constructing `replacement`.
95-
static void replaceOperandByRange(AffineForOp forOp,
96-
SmallVectorImpl<Value> &operands,
97-
SmallVectorImpl<AffineExpr> &symReplacements,
98-
unsigned numDim, bool replaceByMin = false) {
99-
DataFlowSolver solver;
100-
solver.load<DeadCodeAnalysis>();
101-
solver.load<IntegerRangeAnalysis>();
102-
if (failed(solver.initializeAndRun(
103-
forOp->getParentOfType<FunctionOpInterface>())))
104-
return;
105-
106-
// `b` is used to create affineExpr
107-
Builder b(forOp.getContext());
108-
for (unsigned i = numDim, e = operands.size(); i < e; ++i) {
109-
Value operand = operands[i];
110-
auto lattice =
111-
solver.lookupState<dataflow::IntegerValueRangeLattice>(operand);
112-
if (!lattice) {
113-
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
114-
continue;
115-
}
116-
117-
if (lattice->getValue().isUninitialized()) {
118-
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
119-
continue;
120-
}
121-
122-
ConstantIntRanges range = lattice->getValue().getValue();
123-
APInt max = range.smax();
124-
APInt min = range.smin();
125-
unsigned bitNums = max.getBitWidth();
126-
127-
if (APInt::getSignedMaxValue(bitNums) == max &&
128-
APInt::getSignedMinValue(bitNums) == min) {
129-
symReplacements.push_back(b.getAffineSymbolExpr(i - numDim));
130-
continue;
131-
}
132-
133-
if (!replaceByMin)
134-
symReplacements.push_back(b.getAffineConstantExpr(max.getZExtValue()));
135-
else
136-
symReplacements.push_back(b.getAffineConstantExpr(min.getZExtValue()));
137-
}
138-
return;
139-
}
140-
14189
/// Take the min if all trip counts are constant.
14290
static std::optional<uint64_t>
143-
getConstantTripCountFromAffineMap(AffineMap map) {
91+
getConstantTripCountFromAffineMap(AffineMap map,
92+
SmallVectorImpl<Value> &operands,
93+
presburger::BoundType type) {
14494
std::optional<uint64_t> tripCount;
14595
for (auto resultExpr : map.getResults()) {
146-
auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr);
147-
if (!constExpr)
96+
AffineMap subMap =
97+
AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
98+
ValueBoundsConstraintSet::Variable var(subMap, operands);
99+
auto lbBound = ValueBoundsConstraintSet::computeConstantBound(
100+
mlir::presburger::BoundType::LB, var);
101+
auto ubBound = ValueBoundsConstraintSet::computeConstantBound(
102+
mlir::presburger::BoundType::UB, var, nullptr, true);
103+
if (failed(lbBound) || failed(ubBound))
148104
return std::nullopt;
149-
if (tripCount.has_value())
150-
tripCount =
151-
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
152-
else
153-
tripCount = constExpr.getValue();
105+
if (type == presburger::BoundType::LB) {
106+
if (tripCount.has_value())
107+
tripCount =
108+
std::min(*tripCount, static_cast<uint64_t>(lbBound.value()));
109+
else
110+
tripCount = lbBound.value();
111+
} else if (type == presburger::BoundType::UB) {
112+
if (tripCount.has_value())
113+
tripCount =
114+
std::min(*tripCount, static_cast<uint64_t>(ubBound.value()));
115+
else
116+
tripCount = ubBound.value();
117+
} else {
118+
return std::nullopt;
119+
}
154120
}
155121
return tripCount;
156122
}
@@ -166,11 +132,8 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
166132

167133
if (!map)
168134
return std::nullopt;
169-
SmallVector<AffineExpr, 4> symReplacements;
170-
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims());
171-
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
172-
map.getNumSymbols());
173-
return getConstantTripCountFromAffineMap(map);
135+
return getConstantTripCountFromAffineMap(map, operands,
136+
presburger::BoundType::LB);
174137
}
175138

176139
/// Returns the maximum trip count when the operand of forOp has a range. If the
@@ -184,12 +147,8 @@ mlir::affine::getUpperBoundOnTripCount(AffineForOp forOp) {
184147

185148
if (!map)
186149
return std::nullopt;
187-
SmallVector<AffineExpr, 4> symReplacements;
188-
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims(),
189-
true);
190-
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
191-
map.getNumSymbols());
192-
return getConstantTripCountFromAffineMap(map);
150+
return getConstantTripCountFromAffineMap(map, operands,
151+
presburger::BoundType::UB);
193152
}
194153

195154
/// Returns the greatest known integral divisor of the trip count. Affine
@@ -202,18 +161,20 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
202161

203162
if (!map)
204163
return 1;
205-
SmallVector<AffineExpr, 4> symReplacements;
206-
replaceOperandByRange(forOp, operands, symReplacements, map.getNumDims());
207-
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
208-
map.getNumSymbols());
164+
209165
// The largest divisor of the trip count is the GCD of the individual largest
210166
// divisors.
211167
assert(map.getNumResults() >= 1 && "expected one or more results");
212168
std::optional<uint64_t> gcd;
213169
for (auto resultExpr : map.getResults()) {
214170
uint64_t thisGcd;
215-
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
216-
uint64_t tripCount = constExpr.getValue();
171+
AffineMap subMap =
172+
AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
173+
ValueBoundsConstraintSet::Variable var(subMap, operands);
174+
auto lbBound = ValueBoundsConstraintSet::computeConstantBound(
175+
mlir::presburger::BoundType::LB, var);
176+
if (!failed(lbBound)) {
177+
uint64_t tripCount = lbBound.value();
217178
// 0 iteration loops (greatest divisor is 2^64 - 1).
218179
if (tripCount == 0)
219180
thisGcd = std::numeric_limits<uint64_t>::max();

mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -250,34 +250,26 @@ void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
250250
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
251251
SetIntRangeFn setResultRange) {
252252
auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
253-
Value idxResult, Value size) {
253+
Value idxResult) {
254254
if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
255255
return;
256-
APInt sizeInt;
257-
if (matchPattern(size, m_ConstantInt(&sizeInt))) {
258-
ConstantIntRanges dimRange = ConstantIntRanges::constant(sizeInt);
259-
setResultRange(dimResult, dimRange);
260-
ConstantIntRanges idxRange = getIndexRange(0, sizeInt.getZExtValue() - 1);
261-
setResultRange(idxResult, idxRange);
262-
} else {
263-
ConstantIntRanges dimRange =
264-
argRange.intersection(getIndexRange(1, kMaxDim));
265-
setResultRange(dimResult, dimRange);
266-
ConstantIntRanges idxRange =
267-
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
268-
setResultRange(idxResult, idxRange);
269-
}
256+
ConstantIntRanges dimRange =
257+
argRange.intersection(getIndexRange(1, kMaxDim));
258+
setResultRange(dimResult, dimRange);
259+
ConstantIntRanges idxRange =
260+
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
261+
setResultRange(idxResult, idxRange);
270262
};
271263

272264
argRanges = argRanges.drop_front(getAsyncDependencies().size());
273265
KernelDim3 gridDims = getGridSize();
274266
KernelDim3 blockIds = getBlockIds();
275-
setRange(argRanges[0], gridDims.x, blockIds.x, getGridSizeX());
276-
setRange(argRanges[1], gridDims.y, blockIds.y, getGridSizeY());
277-
setRange(argRanges[2], gridDims.z, blockIds.z, getGridSizeZ());
267+
setRange(argRanges[0], gridDims.x, blockIds.x);
268+
setRange(argRanges[1], gridDims.y, blockIds.y);
269+
setRange(argRanges[2], gridDims.z, blockIds.z);
278270
KernelDim3 blockDims = getBlockSize();
279271
KernelDim3 threadIds = getThreadIds();
280-
setRange(argRanges[3], blockDims.x, threadIds.x, getBlockSizeX());
281-
setRange(argRanges[4], blockDims.y, threadIds.y, getBlockSizeY());
282-
setRange(argRanges[5], blockDims.z, threadIds.z, getBlockSizeZ());
272+
setRange(argRanges[3], blockDims.x, threadIds.x);
273+
setRange(argRanges[4], blockDims.y, threadIds.y);
274+
setRange(argRanges[5], blockDims.z, threadIds.z);
283275
}

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
646646
// Compute constant bound for `valueDim`.
647647
int64_t ubAdjustment = closedUB ? 0 : 1;
648648
if (auto bound = cstr.cstr.getConstantBound64(type, pos))
649-
return type == BoundType::UB ? *bound + ubAdjustment : *bound;
649+
if (bound.has_value())
650+
return type == BoundType::UB ? *bound + ubAdjustment : *bound;
650651
return failure();
651652
}
652653

0 commit comments

Comments
 (0)