1212
1313#include " mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1414
15- #include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
16- #include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1715#include " mlir/Analysis/SliceAnalysis.h"
1816#include " mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1917#include " mlir/Dialect/Affine/Analysis/AffineStructures.h"
2018#include " mlir/Dialect/Affine/Analysis/NestedMatcher.h"
2119#include " mlir/Dialect/Affine/IR/AffineOps.h"
2220#include " mlir/Dialect/Affine/IR/AffineValueMap.h"
2321#include " mlir/Interfaces/FunctionInterfaces.h"
22+ #include " mlir/Interfaces/ValueBoundsOpInterface.h"
2423#include " llvm/Support/MathExtras.h"
2524
2625#include " llvm/ADT/DenseSet.h"
3332
3433using namespace mlir ;
3534using namespace mlir ::affine;
36- using namespace mlir ::dataflow;
3735
3836#define DEBUG_TYPE " affine-loop-analysis"
3937
@@ -88,69 +86,37 @@ void mlir::affine::getTripCountMapAndOperands(
8886 tripCountValueMap.getOperands ().end ());
8987}
9088
91- // / By running `IntegerRangeAnalysis` to get the ranges of operand, then fill
92- // / the `symReplacements` with range. If `replaceByMin` is set to true,
93- // / construct `replacement` using the smallest value.By default, the largest
94- // / value will be used for constructing `replacement`.
95- static void replaceOperandByRange (AffineForOp forOp,
96- SmallVectorImpl<Value> &operands,
97- SmallVectorImpl<AffineExpr> &symReplacements,
98- unsigned numDim, bool replaceByMin = false ) {
99- DataFlowSolver solver;
100- solver.load <DeadCodeAnalysis>();
101- solver.load <IntegerRangeAnalysis>();
102- if (failed (solver.initializeAndRun (
103- forOp->getParentOfType <FunctionOpInterface>())))
104- return ;
105-
106- // `b` is used to create affineExpr
107- Builder b (forOp.getContext ());
108- for (unsigned i = numDim, e = operands.size (); i < e; ++i) {
109- Value operand = operands[i];
110- auto lattice =
111- solver.lookupState <dataflow::IntegerValueRangeLattice>(operand);
112- if (!lattice) {
113- symReplacements.push_back (b.getAffineSymbolExpr (i - numDim));
114- continue ;
115- }
116-
117- if (lattice->getValue ().isUninitialized ()) {
118- symReplacements.push_back (b.getAffineSymbolExpr (i - numDim));
119- continue ;
120- }
121-
122- ConstantIntRanges range = lattice->getValue ().getValue ();
123- APInt max = range.smax ();
124- APInt min = range.smin ();
125- unsigned bitNums = max.getBitWidth ();
126-
127- if (APInt::getSignedMaxValue (bitNums) == max &&
128- APInt::getSignedMinValue (bitNums) == min) {
129- symReplacements.push_back (b.getAffineSymbolExpr (i - numDim));
130- continue ;
131- }
132-
133- if (!replaceByMin)
134- symReplacements.push_back (b.getAffineConstantExpr (max.getZExtValue ()));
135- else
136- symReplacements.push_back (b.getAffineConstantExpr (min.getZExtValue ()));
137- }
138- return ;
139- }
140-
14189// / Take the min if all trip counts are constant.
14290static std::optional<uint64_t >
143- getConstantTripCountFromAffineMap (AffineMap map) {
91+ getConstantTripCountFromAffineMap (AffineMap map,
92+ SmallVectorImpl<Value> &operands,
93+ presburger::BoundType type) {
14494 std::optional<uint64_t > tripCount;
14595 for (auto resultExpr : map.getResults ()) {
146- auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr);
147- if (!constExpr)
96+ AffineMap subMap =
97+ AffineMap::get (map.getNumDims (), map.getNumSymbols (), resultExpr);
98+ ValueBoundsConstraintSet::Variable var (subMap, operands);
99+ auto lbBound = ValueBoundsConstraintSet::computeConstantBound (
100+ mlir::presburger::BoundType::LB, var);
101+ auto ubBound = ValueBoundsConstraintSet::computeConstantBound (
102+ mlir::presburger::BoundType::UB, var, nullptr , true );
103+ if (failed (lbBound) || failed (ubBound))
148104 return std::nullopt ;
149- if (tripCount.has_value ())
150- tripCount =
151- std::min (*tripCount, static_cast <uint64_t >(constExpr.getValue ()));
152- else
153- tripCount = constExpr.getValue ();
105+ if (type == presburger::BoundType::LB) {
106+ if (tripCount.has_value ())
107+ tripCount =
108+ std::min (*tripCount, static_cast <uint64_t >(lbBound.value ()));
109+ else
110+ tripCount = lbBound.value ();
111+ } else if (type == presburger::BoundType::UB) {
112+ if (tripCount.has_value ())
113+ tripCount =
114+ std::min (*tripCount, static_cast <uint64_t >(ubBound.value ()));
115+ else
116+ tripCount = ubBound.value ();
117+ } else {
118+ return std::nullopt ;
119+ }
154120 }
155121 return tripCount;
156122}
@@ -166,11 +132,8 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
166132
167133 if (!map)
168134 return std::nullopt ;
169- SmallVector<AffineExpr, 4 > symReplacements;
170- replaceOperandByRange (forOp, operands, symReplacements, map.getNumDims ());
171- map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
172- map.getNumSymbols ());
173- return getConstantTripCountFromAffineMap (map);
135+ return getConstantTripCountFromAffineMap (map, operands,
136+ presburger::BoundType::LB);
174137}
175138
176139// / Returns the maximum trip count when the operand of forOp has a range. If the
@@ -184,12 +147,8 @@ mlir::affine::getUpperBoundOnTripCount(AffineForOp forOp) {
184147
185148 if (!map)
186149 return std::nullopt ;
187- SmallVector<AffineExpr, 4 > symReplacements;
188- replaceOperandByRange (forOp, operands, symReplacements, map.getNumDims (),
189- true );
190- map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
191- map.getNumSymbols ());
192- return getConstantTripCountFromAffineMap (map);
150+ return getConstantTripCountFromAffineMap (map, operands,
151+ presburger::BoundType::UB);
193152}
194153
195154// / Returns the greatest known integral divisor of the trip count. Affine
@@ -202,18 +161,20 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
202161
203162 if (!map)
204163 return 1 ;
205- SmallVector<AffineExpr, 4 > symReplacements;
206- replaceOperandByRange (forOp, operands, symReplacements, map.getNumDims ());
207- map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
208- map.getNumSymbols ());
164+
209165 // The largest divisor of the trip count is the GCD of the individual largest
210166 // divisors.
211167 assert (map.getNumResults () >= 1 && " expected one or more results" );
212168 std::optional<uint64_t > gcd;
213169 for (auto resultExpr : map.getResults ()) {
214170 uint64_t thisGcd;
215- if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
216- uint64_t tripCount = constExpr.getValue ();
171+ AffineMap subMap =
172+ AffineMap::get (map.getNumDims (), map.getNumSymbols (), resultExpr);
173+ ValueBoundsConstraintSet::Variable var (subMap, operands);
174+ auto lbBound = ValueBoundsConstraintSet::computeConstantBound (
175+ mlir::presburger::BoundType::LB, var);
176+ if (!failed (lbBound)) {
177+ uint64_t tripCount = lbBound.value ();
217178 // 0 iteration loops (greatest divisor is 2^64 - 1).
218179 if (tripCount == 0 )
219180 thisGcd = std::numeric_limits<uint64_t >::max ();
0 commit comments