Skip to content

Commit 295b090

Browse files
kumasentoivanradanov
authored andcommitted
[FoldSCFIf] lift store out of scf.if
1 parent 994c12d commit 295b090

File tree

2 files changed

+175
-2
lines changed

2 files changed

+175
-2
lines changed

tools/polymer/lib/Transforms/FoldSCFIf.cc

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ static bool hasSingleStore(Block *block) {
4040
if (memrefs.count(memref))
4141
return false;
4242

43+
// The indices should be defined above the current block.
44+
if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(op)) {
45+
if (any_of(storeOp.getMapOperands(), [&](Value operand) {
46+
return operand.getParentBlock() == block;
47+
}))
48+
return false;
49+
} else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
50+
if (any_of(storeOp.getIndices(), [&](Value operand) {
51+
return operand.getParentBlock() == block;
52+
}))
53+
return false;
54+
}
55+
4356
memrefs.insert(memref);
4457
}
4558
}
@@ -118,9 +131,166 @@ struct MatchIfElsePass : PassWrapper<MatchIfElsePass, OperationPass<FuncOp>> {
118131

119132
/// ---------------------- LiftStoreOps ------------------------------
120133

134+
static bool hasMatchingStores(ArrayRef<Block *> blocks) {
135+
if (blocks.size() <= 1)
136+
return true;
137+
138+
llvm::SetVector<Value> setUnion;
139+
140+
for (Block *block : blocks.drop_front()) {
141+
llvm::SetVector<Value> memrefs;
142+
143+
for (Operation &op : block->getOperations())
144+
if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
145+
Value memref = op.getOperand(1);
146+
assert(!memrefs.count(memref) &&
147+
"Should only apply on blocks that contain single store to each "
148+
"memref.");
149+
150+
memrefs.insert(op.getOperand(1));
151+
}
152+
153+
bool wasEmpty = setUnion.empty();
154+
if (!wasEmpty && setUnion.set_union(memrefs))
155+
return false;
156+
}
157+
158+
return true;
159+
}
160+
161+
namespace {
162+
struct MemRefStoreInfo {
163+
unsigned index;
164+
Type type;
165+
Operation *source;
166+
SmallVector<Value> operands;
167+
};
168+
} // namespace
169+
170+
static void
171+
getMemRefStoreInfo(Block *block,
172+
SmallDenseMap<Value, MemRefStoreInfo> &storeInfo) {
173+
unsigned ord = 0;
174+
for (Operation &op : block->getOperations())
175+
if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
176+
MemRefStoreInfo info;
177+
info.index = ord++;
178+
info.type = op.getOperand(0).getType();
179+
info.source = &op;
180+
181+
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
182+
info.operands = storeOp.getIndices();
183+
else if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(op))
184+
info.operands = storeOp.getMapOperands();
185+
186+
storeInfo[op.getOperand(1)] = info;
187+
}
188+
}
189+
190+
static LogicalResult liftStoreOps(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
191+
Location loc = ifOp.getLoc();
192+
193+
if (!hasMatchingStores({ifOp.thenBlock(), ifOp.elseBlock()}))
194+
return failure();
195+
196+
SmallDenseMap<Value, MemRefStoreInfo> storeInfo;
197+
getMemRefStoreInfo(ifOp.thenBlock(), storeInfo);
198+
199+
SmallVector<Type> storeTypes(storeInfo.size());
200+
for (auto &info : storeInfo)
201+
storeTypes[info.second.index] = info.second.type;
202+
203+
// No need to process further.
204+
if (storeInfo.empty())
205+
return failure();
206+
207+
OpBuilder::InsertionGuard g(b);
208+
b.setInsertionPointAfter(ifOp);
209+
210+
SmallVector<Type> resultTypes(ifOp.getResultTypes());
211+
resultTypes.append(storeTypes);
212+
213+
scf::IfOp newIfOp = b.create<scf::IfOp>(loc, resultTypes, ifOp.condition(),
214+
/*withElseRegion=*/true);
215+
216+
auto cloneBlock = [&](Block *target, Block *source) {
217+
BlockAndValueMapping vmap;
218+
219+
scf::YieldOp yieldOp = cast<scf::YieldOp>(source->getTerminator());
220+
unsigned numExistingResults = yieldOp.getNumOperands();
221+
SmallVector<Value> results(numExistingResults + storeInfo.size());
222+
223+
OpBuilder::InsertionGuard g(b);
224+
b.setInsertionPointToStart(target);
225+
226+
for (Operation &op : source->getOperations()) {
227+
if (isa<memref::StoreOp, mlir::AffineStoreOp>(op)) {
228+
Value memref = op.getOperand(1);
229+
Value toStore = op.getOperand(0);
230+
results[storeInfo[memref].index + numExistingResults] =
231+
vmap.lookupOrDefault(toStore);
232+
} else if (!isa<scf::YieldOp>(op)) {
233+
b.clone(op, vmap);
234+
}
235+
}
236+
237+
for (auto operand : enumerate(yieldOp.getOperands()))
238+
results[operand.index()] = vmap.lookupOrDefault(operand.value());
239+
240+
b.create<scf::YieldOp>(loc, results);
241+
};
242+
243+
cloneBlock(newIfOp.thenBlock(), ifOp.thenBlock());
244+
cloneBlock(newIfOp.elseBlock(), ifOp.elseBlock());
245+
246+
b.setInsertionPointAfter(newIfOp);
247+
for (auto &p : storeInfo) {
248+
Value memref;
249+
MemRefStoreInfo info;
250+
std::tie(memref, info) = p;
251+
252+
if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(info.source))
253+
b.create<mlir::AffineStoreOp>(
254+
loc, newIfOp.getResult(ifOp.getNumResults() + info.index), memref,
255+
storeOp.getAffineMap(), info.operands);
256+
else if (auto storeOp = dyn_cast<memref::StoreOp>(info.source))
257+
b.create<memref::StoreOp>(
258+
loc, newIfOp.getResult(ifOp.getNumResults() + info.index), memref,
259+
info.operands);
260+
}
261+
262+
ifOp.erase();
263+
264+
return success();
265+
}
266+
267+
static bool processLiftStoreOps(FuncOp f, OpBuilder &b) {
268+
bool changed = false;
269+
270+
f.walk([&](scf::IfOp ifOp) {
271+
if (!hasSingleStore(ifOp.thenBlock()) ||
272+
(ifOp.elseBlock() && !hasSingleStore(ifOp.elseBlock())))
273+
return;
274+
if (failed(liftStoreOps(ifOp, f, b)))
275+
return;
276+
277+
changed = true;
278+
});
279+
280+
return changed;
281+
}
282+
121283
namespace {
122284
struct LiftStoreOps : PassWrapper<LiftStoreOps, OperationPass<FuncOp>> {
123-
void runOnOperation() override {}
285+
void runOnOperation() override {
286+
FuncOp f = getOperation();
287+
OpBuilder b(f.getContext());
288+
289+
// For each scf.if, see if it has single store for each memref on each
290+
// branch.
291+
while (processLiftStoreOps(f, b))
292+
;
293+
}
124294
};
125295
} // namespace
126296

tools/polymer/test/polymer-opt/FoldSCFIf/match-store.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@ func @foo(%A: memref<10xf32>, %a: f32, %cond: i1) {
77
return
88
}
99

10-
// CHECK: func @foo
10+
// CHECK: func @foo(%[[A:.*]]: memref<10xf32>, %[[a:.*]]: f32, %[[cond:.*]]: i1)
11+
// CHECK-NEXT: %[[v0:.*]] = affine.load %[[A]][0] : memref<10xf32>
12+
// CHECK-NEXT: %[[v1:.*]] = select %[[cond]], %[[a]], %[[v0]] : f32
13+
// CHECK-NEXT: affine.store %[[v1]], %[[A]][0] : memref<10xf32>

0 commit comments

Comments
 (0)