Skip to content

Commit 75469bb

Browse files
authored
[MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (#158679)
This patch adds a `getStaticTripCount` method to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when possible. It is implemented for the SCF ForOp by revamping the implementation of `constantTripCount` and removing duplicate implementations from SCF.cpp.
1 parent c87be72 commit 75469bb

File tree

10 files changed

+945
-98
lines changed

10 files changed

+945
-98
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for",
152152
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
153153
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
154154
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
155-
"getLoopUpperBounds", "getYieldedValuesMutable",
155+
"getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
156156
"promoteIfSingleIteration", "replaceWithAdditionalYields",
157157
"yieldTiledValuesAndReplace"]>,
158158
AllTypesMatch<["lowerBound", "upperBound", "step"]>,

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
105105
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
106106
ArrayRef<int64_t> values);
107107

108+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
109+
/// The second return value indicates whether the value is an index type
110+
/// and thus the bitwidth is not defined (the APInt will be set to 64 bits).
111+
std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
108112
/// If ofr is a constant integer or an IntegerAttr, return the integer.
109113
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
110114
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
@@ -201,9 +205,26 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
201205
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
202206

203207
/// Return the number of iterations for a loop with a lower bound `lb`, upper
204-
/// bound `ub` and step `step`.
205-
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
206-
OpFoldResult step);
208+
/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
209+
/// comparison between lb and ub is signed or unsigned. A negative step or a
210+
/// lower bound greater than the upper bound are considered invalid and will
211+
/// yield a zero trip count.
212+
/// The `computeUbMinusLb` callback is invoked to compute the difference between
213+
/// the upper and lower bound when not constant. It can be used by the client
214+
/// to compute a static difference when the bounds are not constant.
215+
///
216+
/// For example, consider the following code:
217+
///
218+
/// %ub = arith.addi nsw %lb, %c16_i32 : i32
219+
/// %1 = scf.for %arg0 = %lb to %ub ...
220+
///
221+
/// Here %ub is computed as a static offset from %lb.
222+
/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
223+
/// to avoid overflow, otherwise an overflow would imply a zero trip count.
224+
std::optional<APInt> constantTripCount(
225+
OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
226+
llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
227+
computeUbMinusLb);
207228

208229
/// Idiomatic saturated operations on values like offsets, sizes, and strides.
209230
struct SaturatedInteger {

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
232232
/*defaultImplementation=*/[{
233233
return ::mlir::failure();
234234
}]
235+
>,
236+
InterfaceMethod<[{
237+
Compute the static trip count if possible.
238+
}],
239+
/*retTy=*/"::std::optional<APInt>",
240+
/*methodName=*/"getStaticTripCount",
241+
/*args=*/(ins),
242+
/*methodBody=*/"",
243+
/*defaultImplementation=*/[{
244+
return ::std::nullopt;
245+
}]
235246
>
236247
];
237248

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/IRMapping.h"
2121
#include "mlir/IR/Matchers.h"
22+
#include "mlir/IR/Operation.h"
23+
#include "mlir/IR/OperationSupport.h"
2224
#include "mlir/IR/PatternMatch.h"
2325
#include "mlir/Interfaces/FunctionInterfaces.h"
2426
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
2527
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2628
#include "mlir/Transforms/InliningUtils.h"
2729
#include "llvm/ADT/MapVector.h"
2830
#include "llvm/ADT/SmallPtrSet.h"
31+
#include "llvm/Support/Casting.h"
32+
#include "llvm/Support/DebugLog.h"
33+
#include <optional>
2934

3035
using namespace mlir;
3136
using namespace mlir::scf;
@@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
105110
return nullptr;
106111
}
107112

113+
/// Helper function to compute the difference between two values. This is used
114+
/// by the loop implementations to compute the trip count.
///
/// Returns the constant `c` when `ub` is defined by an `arith.addi` whose
/// left-hand side is `lb` and whose right-hand side matches a constant
/// integer `c`, provided the addition carries the no-wrap flag matching
/// `isSigned` (no-signed-wrap when `isSigned` is true, no-unsigned-wrap
/// otherwise). Returns std::nullopt in every other case; only the `lb + c`
/// operand order is matched.
115+
static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
116+
bool isSigned) {
117+
llvm::APSInt diff;
118+
auto addOp = ub.getDefiningOp<arith::AddIOp>();
119+
// `ub` must be the direct result of an integer addition.
if (!addOp)
120+
return std::nullopt;
121+
// Reject additions that lack the no-wrap flag corresponding to the
// signedness requested by the caller.
if ((isSigned && !addOp.hasNoSignedWrap()) ||
122+
(!isSigned && !addOp.hasNoUnsignedWrap()))
123+
return std::nullopt;
124+
125+
// The addition must be exactly `lb + <constant integer>`; on a match the
// constant is captured into `diff`.
if (addOp.getLhs() != lb ||
126+
!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127+
return std::nullopt;
128+
return diff;
129+
}
130+
108131
//===----------------------------------------------------------------------===//
109132
// ExecuteRegionOp
110133
//===----------------------------------------------------------------------===//
@@ -408,11 +431,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
408431
/// Promotes the loop body of a forOp to its containing block if
409432
/// it can be determined that the loop has a single iteration.
410433
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
411-
std::optional<int64_t> tripCount =
412-
constantTripCount(getLowerBound(), getUpperBound(), getStep());
413-
if (!tripCount.has_value() || tripCount != 1)
434+
std::optional<APInt> tripCount = getStaticTripCount();
435+
LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
436+
<< " for loop "
437+
<< OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
438+
if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
414439
return failure();
415440

441+
if (*tripCount == 0) {
442+
rewriter.replaceAllUsesWith(getResults(), getInitArgs());
443+
rewriter.eraseOp(*this);
444+
return success();
445+
}
446+
416447
// Replace all results with the yielded values.
417448
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
418449
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +677,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
646677
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
647678
for (auto [lb, ub, step] :
648679
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
649-
auto tripCount = constantTripCount(lb, ub, step);
680+
auto tripCount =
681+
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
650682
if (!tripCount.has_value() || *tripCount != 1)
651683
return failure();
652684
}
@@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
10031035
}
10041036
};
10051037

1006-
/// Util function that tries to compute a constant diff between u and l.
1007-
/// Returns std::nullopt when the difference between two AffineValueMap is
1008-
/// dynamic.
1009-
static std::optional<APInt> computeConstDiff(Value l, Value u) {
1010-
IntegerAttr clb, cub;
1011-
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
1012-
llvm::APInt lbValue = clb.getValue();
1013-
llvm::APInt ubValue = cub.getValue();
1014-
return ubValue - lbValue;
1015-
}
1016-
1017-
// Else a simple pattern match for x + c or c + x
1018-
llvm::APInt diff;
1019-
if (matchPattern(
1020-
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1021-
matchPattern(
1022-
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1023-
return diff;
1024-
return std::nullopt;
1025-
}
1026-
10271038
/// Rewriting pattern that erases loops that are known not to iterate, replaces
10281039
/// single-iteration loops with their bodies, and removes empty loops that
10291040
/// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10321043

10331044
LogicalResult matchAndRewrite(ForOp op,
10341045
PatternRewriter &rewriter) const override {
1035-
// If the upper bound is the same as the lower bound, the loop does not
1036-
// iterate, just remove it.
1037-
if (op.getLowerBound() == op.getUpperBound()) {
1046+
std::optional<APInt> tripCount = op.getStaticTripCount();
1047+
if (!tripCount.has_value())
1048+
return rewriter.notifyMatchFailure(op,
1049+
"can't compute constant trip count");
1050+
1051+
if (tripCount->isZero()) {
1052+
LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
1053+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
10381054
rewriter.replaceOp(op, op.getInitArgs());
10391055
return success();
10401056
}
10411057

1042-
std::optional<APInt> diff =
1043-
computeConstDiff(op.getLowerBound(), op.getUpperBound());
1044-
if (!diff)
1045-
return failure();
1046-
1047-
// If the loop is known to have 0 iterations, remove it.
1048-
bool zeroOrLessIterations =
1049-
diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
1050-
if (zeroOrLessIterations) {
1051-
rewriter.replaceOp(op, op.getInitArgs());
1052-
return success();
1053-
}
1054-
1055-
std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1056-
if (!maybeStepValue)
1057-
return failure();
1058-
1059-
// If the loop is known to have 1 iteration, inline its body and remove the
1060-
// loop.
1061-
llvm::APInt stepValue = *maybeStepValue;
1062-
if (stepValue.sge(*diff)) {
1058+
if (tripCount->getSExtValue() == 1) {
1059+
LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
1060+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
10631061
SmallVector<Value, 4> blockArgs;
10641062
blockArgs.reserve(op.getInitArgs().size() + 1);
10651063
blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10721070
Block &block = op.getRegion().front();
10731071
if (!llvm::hasSingleElement(block))
10741072
return failure();
1075-
// If the loop is empty, iterates at least once, and only returns values
1073+
// The loop is empty and iterates at least once; if it only returns values
10761074
// defined outside of the loop, remove it and replace it with yield values.
10771075
if (llvm::any_of(op.getYieldedValues(),
10781076
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
10791077
return failure();
1078+
LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
1079+
"yield operands for loop "
1080+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
10801081
rewriter.replaceOp(op, op.getYieldedValues());
10811082
return success();
10821083
}
@@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
11721173
return Speculation::NotSpeculatable;
11731174
}
11741175

1176+
/// Returns the static trip count of this loop, or std::nullopt when
/// `constantTripCount` cannot determine one. Delegates to `constantTripCount`
/// with the loop's lower bound, upper bound and step; the bound comparison is
/// treated as signed unless `getUnsignedCmp()` is set, and `computeUbMinusLb`
/// is supplied as the callback for computing `ub - lb`.
std::optional<APInt> ForOp::getStaticTripCount() {
1177+
return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1178+
/*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1179+
}
1180+
11751181
//===----------------------------------------------------------------------===//
11761182
// ForallOp
11771183
//===----------------------------------------------------------------------===//
@@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17681774
for (auto [lb, ub, step, iv] :
17691775
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
17701776
op.getMixedStep(), op.getInductionVars())) {
1771-
auto numIterations = constantTripCount(lb, ub, step);
1777+
auto numIterations =
1778+
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
17721779
if (numIterations.has_value()) {
17731780
// Remove the loop if it performs zero iterations.
17741781
if (*numIterations == 0) {
@@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
18391846
op.getMixedStep(), op.getInductionVars())) {
18401847
if (iv.hasNUses(0))
18411848
continue;
1842-
auto numIterations = constantTripCount(lb, ub, step);
1849+
auto numIterations =
1850+
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
18431851
if (!numIterations.has_value() || numIterations.value() != 1) {
18441852
continue;
18451853
}
@@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
30843092
for (auto [lb, ub, step, iv] :
30853093
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
30863094
op.getInductionVars())) {
3087-
auto numIterations = constantTripCount(lb, ub, step);
3095+
auto numIterations =
3096+
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
30883097
if (numIterations.has_value()) {
30893098
// Remove the loop if it performs zero iterations.
30903099
if (*numIterations == 0) {

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/Interfaces/SideEffectInterfaces.h"
2424
#include "mlir/Transforms/RegionUtils.h"
25+
#include "llvm/ADT/APInt.h"
2526
#include "llvm/ADT/STLExtras.h"
2627
#include "llvm/ADT/SmallVector.h"
2728
#include "llvm/Support/DebugLog.h"
@@ -290,14 +291,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
290291
return arith::DivUIOp::create(builder, loc, sum, divisor);
291292
}
292293

293-
/// Returns the trip count of `forOp` if its' low bound, high bound and step are
294-
/// constants, or optional otherwise. Trip count is computed as
295-
/// ceilDiv(highBound - lowBound, step).
296-
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
297-
return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
298-
forOp.getStep());
299-
}
300-
301294
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
302295
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
303296
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -376,7 +369,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
376369
Value stepUnrolled;
377370
bool generateEpilogueLoop = true;
378371

379-
std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
372+
std::optional<APInt> constTripCount = forOp.getStaticTripCount();
380373
if (constTripCount) {
381374
// Constant loop bounds computation.
382375
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -390,7 +383,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
390383
}
391384

392385
int64_t tripCountEvenMultiple =
393-
*constTripCount - (*constTripCount % unrollFactor);
386+
constTripCount->getSExtValue() -
387+
(constTripCount->getSExtValue() % unrollFactor);
394388
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
395389
int64_t stepUnrolledCst = stepCst * unrollFactor;
396390

@@ -486,15 +480,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
486480
/// Unrolls this loop completely.
///
/// Fails if the loop has no static trip count. A zero trip count is reported
/// as success without touching the loop, a trip count of one is handed to
/// `promoteIfSingleIteration`, and any other trip count is used as the unroll
/// factor passed to `loopUnrollByFactor`.
///
/// NOTE(review): the trip count is read with `getSExtValue()` here, whereas
/// `loopUnrollJamByFactor` reads it with `getZExtValue()`. For a bound type
/// narrower than 64 bits, a count with its top bit set would presumably
/// sign-extend to a negative value — confirm which extension is intended.
487481
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
488482
IRRewriter rewriter(forOp.getContext());
489-
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
483+
std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
490484
if (!mayBeConstantTripCount.has_value())
491485
return failure();
492-
uint64_t tripCount = *mayBeConstantTripCount;
493-
if (tripCount == 0)
486+
APInt &tripCount = *mayBeConstantTripCount;
487+
if (tripCount.isZero())
494488
return success();
495-
if (tripCount == 1)
489+
if (tripCount.getSExtValue() == 1)
496490
return forOp.promoteIfSingleIteration(rewriter);
497-
return loopUnrollByFactor(forOp, tripCount);
491+
return loopUnrollByFactor(forOp, tripCount.getSExtValue());
498492
}
499493

500494
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -534,18 +528,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
534528

535529
// Currently, only constant trip count that divided by the unroll factor is
536530
// supported.
537-
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
531+
std::optional<APInt> tripCount = forOp.getStaticTripCount();
538532
if (!tripCount.has_value()) {
539533
// If the trip count is dynamic, do not unroll & jam.
540534
LDBG() << "failed to unroll and jam: trip count could not be determined";
541535
return failure();
542536
}
543-
if (unrollJamFactor > *tripCount) {
537+
if (unrollJamFactor > tripCount->getZExtValue()) {
544538
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
545539
"trip "
546540
"count";
547-
unrollJamFactor = *tripCount;
548-
} else if (*tripCount % unrollJamFactor != 0) {
541+
unrollJamFactor = tripCount->getZExtValue();
542+
} else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
549543
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
550544
"multiple of unroll jam factor";
551545
return failure();

0 commit comments

Comments
 (0)