@@ -40,6 +40,19 @@ static bool hasSingleStore(Block *block) {
40
40
if (memrefs.count (memref))
41
41
return false ;
42
42
43
+ // The indices should be defined above the current block.
44
+ if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(op)) {
45
+ if (any_of (storeOp.getMapOperands (), [&](Value operand) {
46
+ return operand.getParentBlock () == block;
47
+ }))
48
+ return false ;
49
+ } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
50
+ if (any_of (storeOp.getIndices (), [&](Value operand) {
51
+ return operand.getParentBlock () == block;
52
+ }))
53
+ return false ;
54
+ }
55
+
43
56
memrefs.insert (memref);
44
57
}
45
58
}
@@ -118,9 +131,166 @@ struct MatchIfElsePass : PassWrapper<MatchIfElsePass, OperationPass<FuncOp>> {
118
131
119
132
// / ---------------------- LiftStoreOps ------------------------------
120
133
134
+ static bool hasMatchingStores (ArrayRef<Block *> blocks) {
135
+ if (blocks.size () <= 1 )
136
+ return true ;
137
+
138
+ llvm::SetVector<Value> setUnion;
139
+
140
+ for (Block *block : blocks.drop_front ()) {
141
+ llvm::SetVector<Value> memrefs;
142
+
143
+ for (Operation &op : block->getOperations ())
144
+ if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
145
+ Value memref = op.getOperand (1 );
146
+ assert (!memrefs.count (memref) &&
147
+ " Should only apply on blocks that contain single store to each "
148
+ " memref." );
149
+
150
+ memrefs.insert (op.getOperand (1 ));
151
+ }
152
+
153
+ bool wasEmpty = setUnion.empty ();
154
+ if (!wasEmpty && setUnion.set_union (memrefs))
155
+ return false ;
156
+ }
157
+
158
+ return true ;
159
+ }
160
+
161
namespace {
/// Bookkeeping for the (single) store to one memref inside an scf.if branch;
/// filled in by getMemRefStoreInfo and consumed by liftStoreOps.
struct MemRefStoreInfo {
  // Position of this store among the block's stores; doubles as the offset of
  // the extra scf.if result that will carry the stored value.
  unsigned index;
  // Type of the value being stored.
  Type type;
  // The original store op (memref::StoreOp or mlir::AffineStoreOp), kept so
  // the lifted store can be recreated with the same kind/affine map.
  Operation *source;
  // Index operands of the store, reused when recreating it after the scf.if.
  // They must dominate the scf.if (checked by hasSingleStore).
  SmallVector<Value> operands;
};
} // namespace
169
+
170
/// Collects, for every memref stored to in `block`, a MemRefStoreInfo record:
/// the order the store appears in (`index`), the type of the stored value
/// (`type`), the store op itself (`source`), and its index operands
/// (`operands`). Results are keyed by the stored-to memref in `storeInfo`.
///
/// NOTE(review): relies on both memref::StoreOp and mlir::AffineStoreOp
/// placing the stored value at operand 0 and the memref at operand 1 —
/// confirm against the ops' ODS definitions. Assumes at most one store per
/// memref in `block` (enforced by hasSingleStore); a duplicate would
/// silently overwrite the earlier map entry.
static void
getMemRefStoreInfo(Block *block,
                   SmallDenseMap<Value, MemRefStoreInfo> &storeInfo) {
  unsigned ord = 0;
  for (Operation &op : block->getOperations())
    if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
      MemRefStoreInfo info;
      // Ordinal of this store; maps the memref to a result slot later.
      info.index = ord++;
      // Type of the value being stored (operand 0).
      info.type = op.getOperand(0).getType();
      info.source = &op;

      // Capture the index operands so the store can be rebuilt after the
      // scf.if; affine stores use their map operands instead.
      if (auto storeOp = dyn_cast<memref::StoreOp>(op))
        info.operands = storeOp.getIndices();
      else if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(op))
        info.operands = storeOp.getMapOperands();

      // Keyed by the stored-to memref (operand 1).
      storeInfo[op.getOperand(1)] = info;
    }
}
189
+
190
+ static LogicalResult liftStoreOps (scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
191
+ Location loc = ifOp.getLoc ();
192
+
193
+ if (!hasMatchingStores ({ifOp.thenBlock (), ifOp.elseBlock ()}))
194
+ return failure ();
195
+
196
+ SmallDenseMap<Value, MemRefStoreInfo> storeInfo;
197
+ getMemRefStoreInfo (ifOp.thenBlock (), storeInfo);
198
+
199
+ SmallVector<Type> storeTypes (storeInfo.size ());
200
+ for (auto &info : storeInfo)
201
+ storeTypes[info.second .index ] = info.second .type ;
202
+
203
+ // No need to process further.
204
+ if (storeInfo.empty ())
205
+ return failure ();
206
+
207
+ OpBuilder::InsertionGuard g (b);
208
+ b.setInsertionPointAfter (ifOp);
209
+
210
+ SmallVector<Type> resultTypes (ifOp.getResultTypes ());
211
+ resultTypes.append (storeTypes);
212
+
213
+ scf::IfOp newIfOp = b.create <scf::IfOp>(loc, resultTypes, ifOp.condition (),
214
+ /* withElseRegion=*/ true );
215
+
216
+ auto cloneBlock = [&](Block *target, Block *source) {
217
+ BlockAndValueMapping vmap;
218
+
219
+ scf::YieldOp yieldOp = cast<scf::YieldOp>(source->getTerminator ());
220
+ unsigned numExistingResults = yieldOp.getNumOperands ();
221
+ SmallVector<Value> results (numExistingResults + storeInfo.size ());
222
+
223
+ OpBuilder::InsertionGuard g (b);
224
+ b.setInsertionPointToStart (target);
225
+
226
+ for (Operation &op : source->getOperations ()) {
227
+ if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
228
+ Value memref = op.getOperand (1 );
229
+ Value toStore = op.getOperand (0 );
230
+ results[storeInfo[memref].index + numExistingResults] =
231
+ vmap.lookupOrDefault (toStore);
232
+ } else if (!isa<scf::YieldOp>(op)) {
233
+ b.clone (op, vmap);
234
+ }
235
+ }
236
+
237
+ for (auto operand : enumerate(yieldOp.getOperands ()))
238
+ results[operand.index ()] = vmap.lookupOrDefault (operand.value ());
239
+
240
+ b.create <scf::YieldOp>(loc, results);
241
+ };
242
+
243
+ cloneBlock (newIfOp.thenBlock (), ifOp.thenBlock ());
244
+ cloneBlock (newIfOp.elseBlock (), ifOp.elseBlock ());
245
+
246
+ b.setInsertionPointAfter (newIfOp);
247
+ for (auto &p : storeInfo) {
248
+ Value memref;
249
+ MemRefStoreInfo info;
250
+ std::tie (memref, info) = p;
251
+
252
+ if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(info.source ))
253
+ b.create <mlir::AffineStoreOp>(
254
+ loc, newIfOp.getResult (ifOp.getNumResults () + info.index ), memref,
255
+ storeOp.getAffineMap (), info.operands );
256
+ else if (auto storeOp = dyn_cast<memref::StoreOp>(info.source ))
257
+ b.create <memref::StoreOp>(
258
+ loc, newIfOp.getResult (ifOp.getNumResults () + info.index ), memref,
259
+ info.operands );
260
+ }
261
+
262
+ ifOp.erase ();
263
+
264
+ return success ();
265
+ }
266
+
267
+ static bool processLiftStoreOps (FuncOp f, OpBuilder &b) {
268
+ bool changed = false ;
269
+
270
+ f.walk ([&](scf::IfOp ifOp) {
271
+ if (!hasSingleStore (ifOp.thenBlock ()) ||
272
+ (ifOp.elseBlock () && !hasSingleStore (ifOp.elseBlock ())))
273
+ return ;
274
+ if (failed (liftStoreOps (ifOp, f, b)))
275
+ return ;
276
+
277
+ changed = true ;
278
+ });
279
+
280
+ return changed;
281
+ }
282
+
121
283
namespace {
122
284
struct LiftStoreOps : PassWrapper<LiftStoreOps, OperationPass<FuncOp>> {
123
- void runOnOperation () override {}
285
+ void runOnOperation () override {
286
+ FuncOp f = getOperation ();
287
+ OpBuilder b (f.getContext ());
288
+
289
+ // For each scf.if, see if it has single store for each memref on each
290
+ // branch.
291
+ while (processLiftStoreOps (f, b))
292
+ ;
293
+ }
124
294
};
125
295
} // namespace
126
296
0 commit comments