8
8
#include " mlir/Analysis/Utils.h"
9
9
#include " mlir/Dialect/Affine/IR/AffineOps.h"
10
10
#include " mlir/Dialect/Affine/IR/AffineValueMap.h"
11
+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
11
12
#include " mlir/Dialect/SCF/SCF.h"
12
13
#include " mlir/IR/BlockAndValueMapping.h"
13
14
#include " mlir/IR/Builders.h"
@@ -30,11 +31,110 @@ using namespace llvm;
30
31
31
32
#define DEBUG_TYPE " fold-scf-if"
32
33
34
+ static bool hasSingleStore (Block *block) {
35
+ llvm::SetVector<Value> memrefs;
36
+
37
+ for (Operation &op : block->getOperations ()) {
38
+ if (isa<mlir::AffineStoreOp, memref::StoreOp>(op)) {
39
+ Value memref = op.getOperand (1 );
40
+ if (memrefs.count (memref))
41
+ return false ;
42
+
43
+ memrefs.insert (memref);
44
+ }
45
+ }
46
+ return true ;
47
+ }
48
+
49
+ namespace {
50
+ struct MatchIfElsePass : PassWrapper<MatchIfElsePass, OperationPass<FuncOp>> {
51
+ void runOnOperation () override {
52
+ FuncOp f = getOperation ();
53
+ OpBuilder b (f.getContext ());
54
+
55
+ // If there is no store in the target block for a specific memref stored in
56
+ // the source block, we will create a dummy load.
57
+ auto matchStore = [&](Block *target, Block *source, Location loc) {
58
+ llvm::SetVector<Value> memrefs;
59
+
60
+ for (Operation &op : target->getOperations ())
61
+ if (isa<memref::StoreOp, mlir::AffineStoreOp>(op))
62
+ memrefs.insert (op.getOperand (1 ));
63
+
64
+ b.setInsertionPoint (target->getTerminator ());
65
+ for (Operation &op : source->getOperations ()) {
66
+ if (!isa<mlir::AffineStoreOp, memref::StoreOp>(op))
67
+ continue ;
68
+ Value memref = op.getOperand (1 );
69
+ if (memrefs.count (memref)) // has been stored to
70
+ continue ;
71
+
72
+ if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
73
+ Value value = b.create <AffineLoadOp>(loc, storeOp.getMemRef (),
74
+ storeOp.getAffineMap (),
75
+ storeOp.getMapOperands ());
76
+ b.create <AffineStoreOp>(loc, value, storeOp.getMemRef (),
77
+ storeOp.getAffineMap (),
78
+ storeOp.getMapOperands ());
79
+ } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
80
+ Value value = b.create <memref::LoadOp>(loc, storeOp.getMemRef (),
81
+ storeOp.getIndices ());
82
+ b.create <memref::StoreOp>(loc, value, storeOp.getMemRef (),
83
+ storeOp.indices ());
84
+ }
85
+ }
86
+ };
87
+
88
+ f.walk ([&](scf::IfOp ifOp) {
89
+ Location loc = ifOp.getLoc ();
90
+ OpBuilder::InsertionGuard g (b);
91
+
92
+ // If there is no else block, initialize one with a terminating yield.
93
+ if (!ifOp.elseBlock ()) {
94
+ ifOp.elseRegion ().emplaceBlock ();
95
+
96
+ b.setInsertionPointToStart (ifOp.elseBlock ());
97
+ b.create <scf::YieldOp>(loc);
98
+ }
99
+
100
+ if (!hasSingleStore (ifOp.thenBlock ()) ||
101
+ !hasSingleStore (ifOp.elseBlock ())) {
102
+ LLVM_DEBUG (
103
+ dbgs ()
104
+ << " Skipped if:\n "
105
+ << ifOp
106
+ << " \n due to there are duplicated stores on the same memref." );
107
+ return ;
108
+ }
109
+
110
+ matchStore (ifOp.elseBlock (), ifOp.thenBlock (), loc);
111
+ matchStore (ifOp.thenBlock (), ifOp.elseBlock (), loc);
112
+
113
+ LLVM_DEBUG (dbgs () << " Matched else block:\n " << ifOp << " \n\n " );
114
+ });
115
+ }
116
+ };
117
+ } // namespace
118
+
119
+ // / ---------------------- LiftStoreOps ------------------------------
120
+
121
+ namespace {
122
+ struct LiftStoreOps : PassWrapper<LiftStoreOps, OperationPass<FuncOp>> {
123
+ void runOnOperation () override {}
124
+ };
125
+ } // namespace
126
+
127
+ // / ---------------------- FoldSCFIf ----------------------------------
128
+
33
129
static void foldSCFIf (scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
34
130
Location loc = ifOp.getLoc ();
35
131
36
132
LLVM_DEBUG (dbgs () << " Working on ifOp: " << ifOp << " \n\n " );
37
133
134
+ if (!hasSingleStore (ifOp.thenBlock ()) ||
135
+ (ifOp.elseBlock () && !hasSingleStore (ifOp.elseBlock ())))
136
+ return ;
137
+
38
138
OpBuilder::InsertionGuard g (b);
39
139
b.setInsertionPointAfter (ifOp);
40
140
@@ -95,7 +195,10 @@ struct FoldSCFIfPass : PassWrapper<FoldSCFIfPass, OperationPass<FuncOp>> {
95
195
} // namespace
96
196
97
197
void polymer::registerFoldSCFIfPass () {
98
- PassPipelineRegistration<>(
99
- " fold-scf-if" , " Fold scf.if into select." ,
100
- [](OpPassManager &pm) { pm.addPass (std::make_unique<FoldSCFIfPass>()); });
198
+ PassPipelineRegistration<>(" fold-scf-if" , " Fold scf.if into select." ,
199
+ [](OpPassManager &pm) {
200
+ pm.addPass (std::make_unique<MatchIfElsePass>());
201
+ pm.addPass (std::make_unique<LiftStoreOps>());
202
+ pm.addPass (std::make_unique<FoldSCFIfPass>());
203
+ });
101
204
}
0 commit comments