19
19
#include " mlir/IR/BuiltinAttributes.h"
20
20
#include " mlir/IR/IRMapping.h"
21
21
#include " mlir/IR/Matchers.h"
22
+ #include " mlir/IR/Operation.h"
23
+ #include " mlir/IR/OperationSupport.h"
22
24
#include " mlir/IR/PatternMatch.h"
23
25
#include " mlir/Interfaces/FunctionInterfaces.h"
24
26
#include " mlir/Interfaces/ParallelCombiningOpInterface.h"
25
27
#include " mlir/Interfaces/ValueBoundsOpInterface.h"
26
28
#include " mlir/Transforms/InliningUtils.h"
27
29
#include " llvm/ADT/MapVector.h"
28
30
#include " llvm/ADT/SmallPtrSet.h"
31
+ #include " llvm/Support/Casting.h"
32
+ #include " llvm/Support/DebugLog.h"
33
+ #include < optional>
29
34
30
35
using namespace mlir ;
31
36
using namespace mlir ::scf;
@@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion,
105
110
return nullptr ;
106
111
}
107
112
113
+ // / Helper function to compute the difference between two values. This is used
114
+ // / by the loop implementations to compute the trip count.
115
+ static std::optional<llvm::APSInt> computeUbMinusLb (Value lb, Value ub,
116
+ bool isSigned) {
117
+ llvm::APSInt diff;
118
+ auto addOp = ub.getDefiningOp <arith::AddIOp>();
119
+ if (!addOp)
120
+ return std::nullopt ;
121
+ if ((isSigned && !addOp.hasNoSignedWrap ()) ||
122
+ (!isSigned && !addOp.hasNoUnsignedWrap ()))
123
+ return std::nullopt ;
124
+
125
+ if (addOp.getLhs () != lb ||
126
+ !matchPattern (addOp.getRhs (), m_ConstantInt (&diff)))
127
+ return std::nullopt ;
128
+ return diff;
129
+ }
130
+
108
131
// ===----------------------------------------------------------------------===//
109
132
// ExecuteRegionOp
110
133
// ===----------------------------------------------------------------------===//
@@ -408,11 +431,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
408
431
// / Promotes the loop body of a forOp to its containing block if the forOp
409
432
// / it can be determined that the loop has a single iteration.
410
433
LogicalResult ForOp::promoteIfSingleIteration (RewriterBase &rewriter) {
411
- std::optional<int64_t > tripCount =
412
- constantTripCount (getLowerBound (), getUpperBound (), getStep ());
413
- if (!tripCount.has_value () || tripCount != 1 )
434
+ std::optional<APInt> tripCount = getStaticTripCount ();
435
+ LDBG () << " promoteIfSingleIteration tripCount is " << tripCount
436
+ << " for loop "
437
+ << OpWithFlags (getOperation (), OpPrintingFlags ().skipRegions ());
438
+ if (!tripCount.has_value () || tripCount->getSExtValue () > 1 )
414
439
return failure ();
415
440
441
+ if (*tripCount == 0 ) {
442
+ rewriter.replaceAllUsesWith (getResults (), getInitArgs ());
443
+ rewriter.eraseOp (*this );
444
+ return success ();
445
+ }
446
+
416
447
// Replace all results with the yielded values.
417
448
auto yieldOp = cast<scf::YieldOp>(getBody ()->getTerminator ());
418
449
rewriter.replaceAllUsesWith (getResults (), getYieldedValues ());
@@ -646,7 +677,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
646
677
LogicalResult scf::ForallOp::promoteIfSingleIteration (RewriterBase &rewriter) {
647
678
for (auto [lb, ub, step] :
648
679
llvm::zip (getMixedLowerBound (), getMixedUpperBound (), getMixedStep ())) {
649
- auto tripCount = constantTripCount (lb, ub, step);
680
+ auto tripCount =
681
+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
650
682
if (!tripCount.has_value () || *tripCount != 1 )
651
683
return failure ();
652
684
}
@@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
1003
1035
}
1004
1036
};
1005
1037
1006
- // / Util function that tries to compute a constant diff between u and l.
1007
- // / Returns std::nullopt when the difference between two AffineValueMap is
1008
- // / dynamic.
1009
- static std::optional<APInt> computeConstDiff (Value l, Value u) {
1010
- IntegerAttr clb, cub;
1011
- if (matchPattern (l, m_Constant (&clb)) && matchPattern (u, m_Constant (&cub))) {
1012
- llvm::APInt lbValue = clb.getValue ();
1013
- llvm::APInt ubValue = cub.getValue ();
1014
- return ubValue - lbValue;
1015
- }
1016
-
1017
- // Else a simple pattern match for x + c or c + x
1018
- llvm::APInt diff;
1019
- if (matchPattern (
1020
- u, m_Op<arith::AddIOp>(matchers::m_Val (l), m_ConstantInt (&diff))) ||
1021
- matchPattern (
1022
- u, m_Op<arith::AddIOp>(m_ConstantInt (&diff), matchers::m_Val (l))))
1023
- return diff;
1024
- return std::nullopt ;
1025
- }
1026
-
1027
1038
// / Rewriting pattern that erases loops that are known not to iterate, replaces
1028
1039
// / single-iteration loops with their bodies, and removes empty loops that
1029
1040
// / iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1032
1043
1033
1044
LogicalResult matchAndRewrite (ForOp op,
1034
1045
PatternRewriter &rewriter) const override {
1035
- // If the upper bound is the same as the lower bound, the loop does not
1036
- // iterate, just remove it.
1037
- if (op.getLowerBound () == op.getUpperBound ()) {
1046
+ std::optional<APInt> tripCount = op.getStaticTripCount ();
1047
+ if (!tripCount.has_value ())
1048
+ return rewriter.notifyMatchFailure (op,
1049
+ " can't compute constant trip count" );
1050
+
1051
+ if (tripCount->isZero ()) {
1052
+ LDBG () << " SimplifyTrivialLoops tripCount is 0 for loop "
1053
+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
1038
1054
rewriter.replaceOp (op, op.getInitArgs ());
1039
1055
return success ();
1040
1056
}
1041
1057
1042
- std::optional<APInt> diff =
1043
- computeConstDiff (op.getLowerBound (), op.getUpperBound ());
1044
- if (!diff)
1045
- return failure ();
1046
-
1047
- // If the loop is known to have 0 iterations, remove it.
1048
- bool zeroOrLessIterations =
1049
- diff->isZero () || (!op.getUnsignedCmp () && diff->isNegative ());
1050
- if (zeroOrLessIterations) {
1051
- rewriter.replaceOp (op, op.getInitArgs ());
1052
- return success ();
1053
- }
1054
-
1055
- std::optional<llvm::APInt> maybeStepValue = op.getConstantStep ();
1056
- if (!maybeStepValue)
1057
- return failure ();
1058
-
1059
- // If the loop is known to have 1 iteration, inline its body and remove the
1060
- // loop.
1061
- llvm::APInt stepValue = *maybeStepValue;
1062
- if (stepValue.sge (*diff)) {
1058
+ if (tripCount->getSExtValue () == 1 ) {
1059
+ LDBG () << " SimplifyTrivialLoops tripCount is 1 for loop "
1060
+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
1063
1061
SmallVector<Value, 4 > blockArgs;
1064
1062
blockArgs.reserve (op.getInitArgs ().size () + 1 );
1065
1063
blockArgs.push_back (op.getLowerBound ());
@@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1072
1070
Block &block = op.getRegion ().front ();
1073
1071
if (!llvm::hasSingleElement (block))
1074
1072
return failure ();
1075
- // If the loop is empty, iterates at least once, and only returns values
1073
+ // The loop is empty and iterates at least once, if it only returns values
1076
1074
// defined outside of the loop, remove it and replace it with yield values.
1077
1075
if (llvm::any_of (op.getYieldedValues (),
1078
1076
[&](Value v) { return !op.isDefinedOutsideOfLoop (v); }))
1079
1077
return failure ();
1078
+ LDBG () << " SimplifyTrivialLoops empty body loop allows replacement with "
1079
+ " yield operands for loop "
1080
+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
1080
1081
rewriter.replaceOp (op, op.getYieldedValues ());
1081
1082
return success ();
1082
1083
}
@@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
1172
1173
return Speculation::NotSpeculatable;
1173
1174
}
1174
1175
1176
+ std::optional<APInt> ForOp::getStaticTripCount () {
1177
+ return constantTripCount (getLowerBound (), getUpperBound (), getStep (),
1178
+ /* isSigned=*/ !getUnsignedCmp (), computeUbMinusLb);
1179
+ }
1180
+
1175
1181
// ===----------------------------------------------------------------------===//
1176
1182
// ForallOp
1177
1183
// ===----------------------------------------------------------------------===//
@@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
1768
1774
for (auto [lb, ub, step, iv] :
1769
1775
llvm::zip (op.getMixedLowerBound (), op.getMixedUpperBound (),
1770
1776
op.getMixedStep (), op.getInductionVars ())) {
1771
- auto numIterations = constantTripCount (lb, ub, step);
1777
+ auto numIterations =
1778
+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
1772
1779
if (numIterations.has_value ()) {
1773
1780
// Remove the loop if it performs zero iterations.
1774
1781
if (*numIterations == 0 ) {
@@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1839
1846
op.getMixedStep (), op.getInductionVars ())) {
1840
1847
if (iv.hasNUses (0 ))
1841
1848
continue ;
1842
- auto numIterations = constantTripCount (lb, ub, step);
1849
+ auto numIterations =
1850
+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
1843
1851
if (!numIterations.has_value () || numIterations.value () != 1 ) {
1844
1852
continue ;
1845
1853
}
@@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
3084
3092
for (auto [lb, ub, step, iv] :
3085
3093
llvm::zip (op.getLowerBound (), op.getUpperBound (), op.getStep (),
3086
3094
op.getInductionVars ())) {
3087
- auto numIterations = constantTripCount (lb, ub, step);
3095
+ auto numIterations =
3096
+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
3088
3097
if (numIterations.has_value ()) {
3089
3098
// Remove the loop if it performs zero iterations.
3090
3099
if (*numIterations == 0 ) {
0 commit comments