1
1
#include " PassDetails.h"
2
2
3
3
#include " mlir/Dialect/Func/IR/FuncOps.h"
4
+ #include " mlir/Dialect/Math/IR/Math.h"
4
5
#include " mlir/Dialect/SCF/Passes.h"
5
6
#include " mlir/Dialect/SCF/SCF.h"
6
7
#include " mlir/IR/BlockAndValueMapping.h"
@@ -1910,6 +1911,10 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
1910
1911
auto arg = std::get<0 >(pair);
1911
1912
auto afarg = std::get<1 >(pair);
1912
1913
auto res = std::get<2 >(pair);
1914
+ if (!op.getBefore ().isAncestor (arg.getParentRegion ())) {
1915
+ res.replaceAllUsesWith (arg);
1916
+ afarg.replaceAllUsesWith (arg);
1917
+ }
1913
1918
if (afarg.use_empty () && res.use_empty ()) {
1914
1919
eraseArgs.push_back ((unsigned )i);
1915
1920
} else if (valueOffsets.find (arg.getAsOpaquePointer ()) !=
@@ -2088,6 +2093,129 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2088
2093
}
2089
2094
};
2090
2095
2096
+ // If and and with something is preventing creating a for
2097
+ // move the and into the after body guarded by an if
2098
+ struct WhileShiftToInduction : public OpRewritePattern <WhileOp> {
2099
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
2100
+
2101
+ LogicalResult matchAndRewrite (WhileOp loop,
2102
+ PatternRewriter &rewriter) const override {
2103
+ auto condOp = loop.getConditionOp ();
2104
+
2105
+ if (!llvm::hasNItems (loop.getBefore ().back (), 2 ))
2106
+ return failure ();
2107
+
2108
+ auto cmpIOp = condOp.getCondition ().getDefiningOp <CmpIOp>();
2109
+ if (!cmpIOp) {
2110
+ return failure ();
2111
+ }
2112
+
2113
+ if (cmpIOp.getPredicate () != CmpIPredicate::ugt)
2114
+ return failure ();
2115
+
2116
+ if (!matchPattern (cmpIOp.getRhs (), m_Zero ()))
2117
+ return failure ();
2118
+
2119
+ auto indVar = cmpIOp.getLhs ().dyn_cast <BlockArgument>();
2120
+ if (!indVar)
2121
+ return failure ();
2122
+
2123
+ if (indVar.getOwner () != &loop.getBefore ().front ())
2124
+ return failure ();
2125
+
2126
+ auto endYield = cast<YieldOp>(loop.getAfter ().back ().getTerminator ());
2127
+
2128
+ // Check that the block argument is actually an induction var:
2129
+ // Namely, its next value adds to the previous with an invariant step.
2130
+ auto shiftOp =
2131
+ endYield.getResults ()[indVar.getArgNumber ()].getDefiningOp <ShRUIOp>();
2132
+ if (!shiftOp)
2133
+ return failure ();
2134
+
2135
+ if (!matchPattern (shiftOp.getRhs (), m_One ()))
2136
+ return failure ();
2137
+
2138
+ auto prevIndVar = shiftOp.getLhs ().dyn_cast <BlockArgument>();
2139
+ if (!prevIndVar)
2140
+ return failure ();
2141
+
2142
+ if (prevIndVar.getOwner () != &loop.getAfter ().front ())
2143
+ return failure ();
2144
+
2145
+ if (condOp.getOperand (1 + prevIndVar.getArgNumber ()) != indVar)
2146
+ return failure ();
2147
+
2148
+ auto startingV = loop.getInits ()[indVar.getArgNumber ()];
2149
+
2150
+ Value lz =
2151
+ rewriter.create <math::CountLeadingZerosOp>(loop.getLoc (), startingV);
2152
+ if (!lz.getType ().isIndex ())
2153
+ lz = rewriter.create <IndexCastOp>(loop.getLoc (), rewriter.getIndexType (),
2154
+ lz);
2155
+
2156
+ auto len = rewriter.create <SubIOp>(
2157
+ loop.getLoc (),
2158
+ rewriter.create <ConstantIndexOp>(
2159
+ loop.getLoc (), indVar.getType ().getIntOrFloatBitWidth ()),
2160
+ lz);
2161
+
2162
+ SmallVector<Value> newInits (loop.getInits ());
2163
+ newInits[indVar.getArgNumber ()] =
2164
+ rewriter.create <ConstantIndexOp>(loop.getLoc (), 0 );
2165
+ SmallVector<Type> postTys (loop.getResultTypes ());
2166
+ postTys.push_back (rewriter.getIndexType ());
2167
+
2168
+ auto newWhile = rewriter.create <WhileOp>(loop.getLoc (), postTys, newInits);
2169
+ rewriter.createBlock (&newWhile.getBefore ());
2170
+
2171
+ BlockAndValueMapping map;
2172
+ Value newIndVar;
2173
+ for (auto a : loop.getBefore ().front ().getArguments ()) {
2174
+ auto arg = newWhile.getBefore ().addArgument (
2175
+ a == indVar ? rewriter.getIndexType () : a.getType (), a.getLoc ());
2176
+ if (a != indVar)
2177
+ map.map (a, arg);
2178
+ else
2179
+ newIndVar = arg;
2180
+ }
2181
+
2182
+ rewriter.setInsertionPointToEnd (&newWhile.getBefore ().front ());
2183
+ Value newCmp = rewriter.create <CmpIOp>(cmpIOp.getLoc (), CmpIPredicate::ult,
2184
+ newIndVar, len);
2185
+ map.map (cmpIOp, newCmp);
2186
+
2187
+ Value newIndVarTyped = newIndVar;
2188
+ if (newIndVarTyped.getType () != indVar.getType ())
2189
+ newIndVarTyped = rewriter.create <arith::IndexCastOp>(
2190
+ shiftOp.getLoc (), indVar.getType (), newIndVar);
2191
+ map.map (indVar, rewriter.create <ShRUIOp>(shiftOp.getLoc (), startingV,
2192
+ newIndVarTyped));
2193
+ SmallVector<Value> remapped;
2194
+ for (auto o : condOp.getArgs ())
2195
+ remapped.push_back (map.lookup (o));
2196
+ remapped.push_back (newIndVar);
2197
+ rewriter.create <ConditionOp>(condOp.getLoc (), newCmp, remapped);
2198
+
2199
+ newWhile.getAfter ().takeBody (loop.getAfter ());
2200
+
2201
+ auto newPostInd = newWhile.getAfter ().front ().addArgument (
2202
+ rewriter.getIndexType (), loop.getLoc ());
2203
+ auto yieldOp =
2204
+ cast<scf::YieldOp>(newWhile.getAfter ().front ().getTerminator ());
2205
+ SmallVector<Value> yields (yieldOp.getOperands ());
2206
+ rewriter.setInsertionPointToEnd (&newWhile.getAfter ().front ());
2207
+ yields[indVar.getArgNumber ()] = rewriter.create <AddIOp>(
2208
+ loop.getLoc (), newPostInd,
2209
+ rewriter.create <arith::ConstantIndexOp>(loop.getLoc (), 1 ));
2210
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(yieldOp, yields);
2211
+
2212
+ SmallVector<Value> res (newWhile.getResults ());
2213
+ res.pop_back ();
2214
+ rewriter.replaceOp (loop, res);
2215
+ return success ();
2216
+ }
2217
+ };
2218
+
2091
2219
void CanonicalizeFor::runOnOperation () {
2092
2220
mlir::RewritePatternSet rpl (getOperation ()->getContext ());
2093
2221
rpl.add <PropagateInLoopBody, ForOpInductionReplacement, RemoveUnusedArgs,
@@ -2097,6 +2225,8 @@ void CanonicalizeFor::runOnOperation() {
2097
2225
2098
2226
ReplaceRedundantArgs,
2099
2227
2228
+ WhileShiftToInduction,
2229
+
2100
2230
ForBreakAddUpgrade, RemoveUnusedResults,
2101
2231
2102
2232
MoveWhileAndDown, MoveWhileDown3, MoveWhileInvariantIfResult,
0 commit comments