Skip to content

Commit 994c12d

Browse files
kumasentoivanradanov
authored andcommitted
[FoldSCFIf] match store between branches
1 parent 4fd974f commit 994c12d

File tree

2 files changed

+116
-3
lines changed

2 files changed

+116
-3
lines changed

tools/polymer/lib/Transforms/FoldSCFIf.cc

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mlir/Analysis/Utils.h"
99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1112
#include "mlir/Dialect/SCF/SCF.h"
1213
#include "mlir/IR/BlockAndValueMapping.h"
1314
#include "mlir/IR/Builders.h"
@@ -30,11 +31,110 @@ using namespace llvm;
3031

3132
#define DEBUG_TYPE "fold-scf-if"
3233

34+
static bool hasSingleStore(Block *block) {
35+
llvm::SetVector<Value> memrefs;
36+
37+
for (Operation &op : block->getOperations()) {
38+
if (isa<mlir::AffineStoreOp, memref::StoreOp>(op)) {
39+
Value memref = op.getOperand(1);
40+
if (memrefs.count(memref))
41+
return false;
42+
43+
memrefs.insert(memref);
44+
}
45+
}
46+
return true;
47+
}
48+
49+
namespace {
50+
struct MatchIfElsePass : PassWrapper<MatchIfElsePass, OperationPass<FuncOp>> {
51+
void runOnOperation() override {
52+
FuncOp f = getOperation();
53+
OpBuilder b(f.getContext());
54+
55+
// If there is no store in the target block for a specific memref stored in
56+
// the source block, we will create a dummy load.
57+
auto matchStore = [&](Block *target, Block *source, Location loc) {
58+
llvm::SetVector<Value> memrefs;
59+
60+
for (Operation &op : target->getOperations())
61+
if (isa<memref::StoreOp, mlir::AffineStoreOp>(op))
62+
memrefs.insert(op.getOperand(1));
63+
64+
b.setInsertionPoint(target->getTerminator());
65+
for (Operation &op : source->getOperations()) {
66+
if (!isa<mlir::AffineStoreOp, memref::StoreOp>(op))
67+
continue;
68+
Value memref = op.getOperand(1);
69+
if (memrefs.count(memref)) // has been stored to
70+
continue;
71+
72+
if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
73+
Value value = b.create<AffineLoadOp>(loc, storeOp.getMemRef(),
74+
storeOp.getAffineMap(),
75+
storeOp.getMapOperands());
76+
b.create<AffineStoreOp>(loc, value, storeOp.getMemRef(),
77+
storeOp.getAffineMap(),
78+
storeOp.getMapOperands());
79+
} else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
80+
Value value = b.create<memref::LoadOp>(loc, storeOp.getMemRef(),
81+
storeOp.getIndices());
82+
b.create<memref::StoreOp>(loc, value, storeOp.getMemRef(),
83+
storeOp.indices());
84+
}
85+
}
86+
};
87+
88+
f.walk([&](scf::IfOp ifOp) {
89+
Location loc = ifOp.getLoc();
90+
OpBuilder::InsertionGuard g(b);
91+
92+
// If there is no else block, initialize one with a terminating yield.
93+
if (!ifOp.elseBlock()) {
94+
ifOp.elseRegion().emplaceBlock();
95+
96+
b.setInsertionPointToStart(ifOp.elseBlock());
97+
b.create<scf::YieldOp>(loc);
98+
}
99+
100+
if (!hasSingleStore(ifOp.thenBlock()) ||
101+
!hasSingleStore(ifOp.elseBlock())) {
102+
LLVM_DEBUG(
103+
dbgs()
104+
<< "Skipped if:\n"
105+
<< ifOp
106+
<< "\ndue to there are duplicated stores on the same memref.");
107+
return;
108+
}
109+
110+
matchStore(ifOp.elseBlock(), ifOp.thenBlock(), loc);
111+
matchStore(ifOp.thenBlock(), ifOp.elseBlock(), loc);
112+
113+
LLVM_DEBUG(dbgs() << "Matched else block:\n" << ifOp << "\n\n");
114+
});
115+
}
116+
};
117+
} // namespace
118+
119+
/// ---------------------- LiftStoreOps ------------------------------
120+
121+
namespace {
122+
struct LiftStoreOps : PassWrapper<LiftStoreOps, OperationPass<FuncOp>> {
123+
void runOnOperation() override {}
124+
};
125+
} // namespace
126+
127+
/// ---------------------- FoldSCFIf ----------------------------------
128+
33129
static void foldSCFIf(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
34130
Location loc = ifOp.getLoc();
35131

36132
LLVM_DEBUG(dbgs() << "Working on ifOp: " << ifOp << "\n\n");
37133

134+
if (!hasSingleStore(ifOp.thenBlock()) ||
135+
(ifOp.elseBlock() && !hasSingleStore(ifOp.elseBlock())))
136+
return;
137+
38138
OpBuilder::InsertionGuard g(b);
39139
b.setInsertionPointAfter(ifOp);
40140

@@ -95,7 +195,10 @@ struct FoldSCFIfPass : PassWrapper<FoldSCFIfPass, OperationPass<FuncOp>> {
95195
} // namespace
96196

97197
void polymer::registerFoldSCFIfPass() {
98-
PassPipelineRegistration<>(
99-
"fold-scf-if", "Fold scf.if into select.",
100-
[](OpPassManager &pm) { pm.addPass(std::make_unique<FoldSCFIfPass>()); });
198+
PassPipelineRegistration<>("fold-scf-if", "Fold scf.if into select.",
199+
[](OpPassManager &pm) {
200+
pm.addPass(std::make_unique<MatchIfElsePass>());
201+
pm.addPass(std::make_unique<LiftStoreOps>());
202+
pm.addPass(std::make_unique<FoldSCFIfPass>());
203+
});
101204
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: polymer-opt %s -fold-scf-if | FileCheck %s
2+
3+
func @foo(%A: memref<10xf32>, %a: f32, %cond: i1) {
4+
scf.if %cond {
5+
affine.store %a, %A[0] : memref<10xf32>
6+
}
7+
return
8+
}
9+
10+
// CHECK: func @foo

0 commit comments

Comments
 (0)