Skip to content

Commit efc0215

Browse files
authored
Add shift for canonicalization (#255)
1 parent ffe61f9 commit efc0215

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

lib/polygeist/Passes/CanonicalizeFor.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "PassDetails.h"
22

33
#include "mlir/Dialect/Func/IR/FuncOps.h"
4+
#include "mlir/Dialect/Math/IR/Math.h"
45
#include "mlir/Dialect/SCF/Passes.h"
56
#include "mlir/Dialect/SCF/SCF.h"
67
#include "mlir/IR/BlockAndValueMapping.h"
@@ -1910,6 +1911,10 @@ struct RemoveUnusedCondVar : public OpRewritePattern<WhileOp> {
19101911
auto arg = std::get<0>(pair);
19111912
auto afarg = std::get<1>(pair);
19121913
auto res = std::get<2>(pair);
1914+
if (!op.getBefore().isAncestor(arg.getParentRegion())) {
1915+
res.replaceAllUsesWith(arg);
1916+
afarg.replaceAllUsesWith(arg);
1917+
}
19131918
if (afarg.use_empty() && res.use_empty()) {
19141919
eraseArgs.push_back((unsigned)i);
19151920
} else if (valueOffsets.find(arg.getAsOpaquePointer()) !=
@@ -2088,6 +2093,129 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
20882093
}
20892094
};
20902095

2096+
// If and and with something is preventing creating a for
2097+
// move the and into the after body guarded by an if
2098+
struct WhileShiftToInduction : public OpRewritePattern<WhileOp> {
2099+
using OpRewritePattern<WhileOp>::OpRewritePattern;
2100+
2101+
LogicalResult matchAndRewrite(WhileOp loop,
2102+
PatternRewriter &rewriter) const override {
2103+
auto condOp = loop.getConditionOp();
2104+
2105+
if (!llvm::hasNItems(loop.getBefore().back(), 2))
2106+
return failure();
2107+
2108+
auto cmpIOp = condOp.getCondition().getDefiningOp<CmpIOp>();
2109+
if (!cmpIOp) {
2110+
return failure();
2111+
}
2112+
2113+
if (cmpIOp.getPredicate() != CmpIPredicate::ugt)
2114+
return failure();
2115+
2116+
if (!matchPattern(cmpIOp.getRhs(), m_Zero()))
2117+
return failure();
2118+
2119+
auto indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
2120+
if (!indVar)
2121+
return failure();
2122+
2123+
if (indVar.getOwner() != &loop.getBefore().front())
2124+
return failure();
2125+
2126+
auto endYield = cast<YieldOp>(loop.getAfter().back().getTerminator());
2127+
2128+
// Check that the block argument is actually an induction var:
2129+
// Namely, its next value adds to the previous with an invariant step.
2130+
auto shiftOp =
2131+
endYield.getResults()[indVar.getArgNumber()].getDefiningOp<ShRUIOp>();
2132+
if (!shiftOp)
2133+
return failure();
2134+
2135+
if (!matchPattern(shiftOp.getRhs(), m_One()))
2136+
return failure();
2137+
2138+
auto prevIndVar = shiftOp.getLhs().dyn_cast<BlockArgument>();
2139+
if (!prevIndVar)
2140+
return failure();
2141+
2142+
if (prevIndVar.getOwner() != &loop.getAfter().front())
2143+
return failure();
2144+
2145+
if (condOp.getOperand(1 + prevIndVar.getArgNumber()) != indVar)
2146+
return failure();
2147+
2148+
auto startingV = loop.getInits()[indVar.getArgNumber()];
2149+
2150+
Value lz =
2151+
rewriter.create<math::CountLeadingZerosOp>(loop.getLoc(), startingV);
2152+
if (!lz.getType().isIndex())
2153+
lz = rewriter.create<IndexCastOp>(loop.getLoc(), rewriter.getIndexType(),
2154+
lz);
2155+
2156+
auto len = rewriter.create<SubIOp>(
2157+
loop.getLoc(),
2158+
rewriter.create<ConstantIndexOp>(
2159+
loop.getLoc(), indVar.getType().getIntOrFloatBitWidth()),
2160+
lz);
2161+
2162+
SmallVector<Value> newInits(loop.getInits());
2163+
newInits[indVar.getArgNumber()] =
2164+
rewriter.create<ConstantIndexOp>(loop.getLoc(), 0);
2165+
SmallVector<Type> postTys(loop.getResultTypes());
2166+
postTys.push_back(rewriter.getIndexType());
2167+
2168+
auto newWhile = rewriter.create<WhileOp>(loop.getLoc(), postTys, newInits);
2169+
rewriter.createBlock(&newWhile.getBefore());
2170+
2171+
BlockAndValueMapping map;
2172+
Value newIndVar;
2173+
for (auto a : loop.getBefore().front().getArguments()) {
2174+
auto arg = newWhile.getBefore().addArgument(
2175+
a == indVar ? rewriter.getIndexType() : a.getType(), a.getLoc());
2176+
if (a != indVar)
2177+
map.map(a, arg);
2178+
else
2179+
newIndVar = arg;
2180+
}
2181+
2182+
rewriter.setInsertionPointToEnd(&newWhile.getBefore().front());
2183+
Value newCmp = rewriter.create<CmpIOp>(cmpIOp.getLoc(), CmpIPredicate::ult,
2184+
newIndVar, len);
2185+
map.map(cmpIOp, newCmp);
2186+
2187+
Value newIndVarTyped = newIndVar;
2188+
if (newIndVarTyped.getType() != indVar.getType())
2189+
newIndVarTyped = rewriter.create<arith::IndexCastOp>(
2190+
shiftOp.getLoc(), indVar.getType(), newIndVar);
2191+
map.map(indVar, rewriter.create<ShRUIOp>(shiftOp.getLoc(), startingV,
2192+
newIndVarTyped));
2193+
SmallVector<Value> remapped;
2194+
for (auto o : condOp.getArgs())
2195+
remapped.push_back(map.lookup(o));
2196+
remapped.push_back(newIndVar);
2197+
rewriter.create<ConditionOp>(condOp.getLoc(), newCmp, remapped);
2198+
2199+
newWhile.getAfter().takeBody(loop.getAfter());
2200+
2201+
auto newPostInd = newWhile.getAfter().front().addArgument(
2202+
rewriter.getIndexType(), loop.getLoc());
2203+
auto yieldOp =
2204+
cast<scf::YieldOp>(newWhile.getAfter().front().getTerminator());
2205+
SmallVector<Value> yields(yieldOp.getOperands());
2206+
rewriter.setInsertionPointToEnd(&newWhile.getAfter().front());
2207+
yields[indVar.getArgNumber()] = rewriter.create<AddIOp>(
2208+
loop.getLoc(), newPostInd,
2209+
rewriter.create<arith::ConstantIndexOp>(loop.getLoc(), 1));
2210+
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, yields);
2211+
2212+
SmallVector<Value> res(newWhile.getResults());
2213+
res.pop_back();
2214+
rewriter.replaceOp(loop, res);
2215+
return success();
2216+
}
2217+
};
2218+
20912219
void CanonicalizeFor::runOnOperation() {
20922220
mlir::RewritePatternSet rpl(getOperation()->getContext());
20932221
rpl.add<PropagateInLoopBody, ForOpInductionReplacement, RemoveUnusedArgs,
@@ -2097,6 +2225,8 @@ void CanonicalizeFor::runOnOperation() {
20972225

20982226
ReplaceRedundantArgs,
20992227

2228+
WhileShiftToInduction,
2229+
21002230
ForBreakAddUpgrade, RemoveUnusedResults,
21012231

21022232
MoveWhileAndDown, MoveWhileDown3, MoveWhileInvariantIfResult,

test/polygeist-opt/shiftloop.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: polygeist-opt --canonicalize-scf-for --split-input-file %s | FileCheck %s
2+
3+
// Input for the shift-loop canonicalization: @foo keeps calling @run on its
// i32 argument while the value is unsigned-greater-than zero, halving it with
// a logical right shift by one after every call. The expectations that follow
// the module require this scf.while to become an scf.for whose upper bound is
// 32 minus the leading-zero count of %arg0.
module {
4+
func.func @foo(%arg0: i32) {
5+
%c1_i32 = arith.constant 1 : i32
6+
%c0_i32 = arith.constant 0 : i32
7+
%0 = scf.while (%arg1 = %arg0) : (i32) -> i32 {
8+
%1 = arith.cmpi ugt, %arg1, %c0_i32 : i32
9+
scf.condition(%1) %arg1 : i32
10+
} do {
11+
^bb0(%arg1: i32):
12+
func.call @run(%arg1) : (i32) -> ()
13+
%1 = arith.shrui %arg1, %c1_i32 : i32
14+
scf.yield %1 : i32
15+
}
16+
return
17+
}
18+
func.func private @run(i32) attributes {llvm.linkage = #llvm.linkage<external>}
19+
}
20+
21+
// CHECK: func.func @foo(%arg0: i32)
22+
// CHECK-NEXT: %c32 = arith.constant 32 : index
23+
// CHECK-NEXT: %c0 = arith.constant 0 : index
24+
// CHECK-NEXT: %c1 = arith.constant 1 : index
25+
// CHECK-NEXT: %0 = math.ctlz %arg0 : i32
26+
// CHECK-NEXT: %1 = arith.index_cast %0 : i32 to index
27+
// CHECK-NEXT: %2 = arith.subi %c32, %1 : index
28+
// CHECK-NEXT: scf.for %arg1 = %c0 to %2 step %c1 {
29+
// CHECK-NEXT: %3 = arith.index_cast %arg1 : index to i32
30+
// CHECK-NEXT: %4 = arith.shrui %arg0, %3 : i32
31+
// CHECK-NEXT: func.call @run(%4) : (i32) -> ()
32+
// CHECK-NEXT: }
33+
// CHECK-NEXT: return
34+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)