Skip to content

Commit 9976553

Browse files
committed
clean up old memrefs
1 parent b65b2df commit 9976553

File tree

1 file changed

+37
-42
lines changed

1 file changed

+37
-42
lines changed

mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
100100
SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
101101
assert(newShape.front() % bankingFactor == 0 &&
102102
"memref shape must be divided by the banking factor");
103-
// Now assuming banking the last dimension
104103
newShape.front() /= bankingFactor;
105104
MemRefType newMemRefType =
106105
MemRefType::get(newShape, originalType.getElementType(),
@@ -153,17 +152,6 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
153152
return banks;
154153
}
155154

156-
Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
157-
uint availableBanks) {
158-
Value availBanksVal =
159-
builder
160-
.create<arith::ConstantOp>(loc, builder.getIndexAttr(availableBanks))
161-
.getResult();
162-
Value offset =
163-
builder.create<arith::DivUIOp>(loc, address, availBanksVal).getResult();
164-
return offset;
165-
}
166-
167155
struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
168156
BankAffineLoadPattern(MLIRContext *context, uint64_t unrollFactor,
169157
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
@@ -172,7 +160,6 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
172160

173161
LogicalResult matchAndRewrite(AffineLoadOp loadOp,
174162
PatternRewriter &rewriter) const override {
175-
llvm::errs() << "load pattern matchAndRewrite\n";
176163
Location loc = loadOp.getLoc();
177164
auto banks = memoryToBanks[loadOp.getMemref()];
178165
Value loadIndex = loadOp.getIndices().front();
@@ -231,7 +218,6 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
231218

232219
LogicalResult matchAndRewrite(AffineStoreOp storeOp,
233220
PatternRewriter &rewriter) const override {
234-
llvm::errs() << "store pattern matchAndRewrite\n";
235221
Location loc = storeOp.getLoc();
236222
auto banks = memoryToBanks[storeOp.getMemref()];
237223
Value storeIndex = storeOp.getIndices().front();
@@ -300,6 +286,7 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
300286
auto banks = memoryToBanks[operand];
301287
newReturnOperands.append(banks.begin(), banks.end());
302288
}
289+
303290
func::FuncOp funcOp = returnOp.getParentOp();
304291
rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
305292
auto newReturnOp =
@@ -310,16 +297,44 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
310297
funcOp.getFunctionType().getInputs(), newReturnType);
311298
funcOp.setType(newFuncType);
312299

313-
if (allOrigMemsUsedByReturn) {
300+
if (allOrigMemsUsedByReturn)
314301
rewriter.replaceOp(returnOp, newReturnOp);
315-
}
302+
316303
return success();
317304
}
318305

319306
private:
320307
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
321308
};
322309

310+
LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals) {
311+
DenseSet<func::FuncOp> funcsToModify;
312+
for (auto &memrefVal : oldMemRefVals) {
313+
if (!memrefVal.use_empty())
314+
continue;
315+
if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
316+
Block *block = blockArg.getOwner();
317+
block->eraseArgument(blockArg.getArgNumber());
318+
if (auto funcOp = dyn_cast<func::FuncOp>(block->getParentOp()))
319+
funcsToModify.insert(funcOp);
320+
} else
321+
memrefVal.getDefiningOp()->erase();
322+
}
323+
324+
// Modify the function type accordingly
325+
for (auto funcOp : funcsToModify) {
326+
SmallVector<Type, 4> newArgTypes;
327+
for (BlockArgument arg : funcOp.getArguments()) {
328+
newArgTypes.push_back(arg.getType());
329+
}
330+
FunctionType newFuncType =
331+
FunctionType::get(funcOp.getContext(), newArgTypes,
332+
funcOp.getFunctionType().getResults());
333+
funcOp.setType(newFuncType);
334+
}
335+
return success();
336+
}
337+
323338
void ParallelUnroll::runOnOperation() {
324339
if (getOperation().isExternal()) {
325340
return;
@@ -335,7 +350,6 @@ void ParallelUnroll::runOnOperation() {
335350
});
336351

337352
auto *ctx = &getContext();
338-
339353
RewritePatternSet patterns(ctx);
340354

341355
patterns.add<BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
@@ -350,33 +364,14 @@ void ParallelUnroll::runOnOperation() {
350364
signalPassFailure();
351365
}
352366

353-
DenseSet<Block *> blocksToModify;
354-
for (auto &[memrefVal, banks] : memoryToBanks) {
355-
if (memrefVal.use_empty()) {
356-
if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
357-
blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
358-
blocksToModify.insert(blockArg.getOwner());
359-
} else {
360-
memrefVal.getDefiningOp()->erase();
361-
}
362-
}
363-
}
367+
// Clean up the old memref values
368+
DenseSet<Value> oldMemRefVals;
369+
for (const auto &pair : memoryToBanks)
370+
oldMemRefVals.insert(pair.first);
364371

365-
for (auto *block : blocksToModify) {
366-
if (!isa<func::FuncOp>(block->getParentOp()))
367-
continue;
368-
func::FuncOp funcOp = cast<func::FuncOp>(block->getParentOp());
369-
SmallVector<Type, 4> newArgTypes;
370-
for (BlockArgument arg : funcOp.getArguments()) {
371-
newArgTypes.push_back(arg.getType());
372-
}
373-
FunctionType newFuncType =
374-
FunctionType::get(funcOp.getContext(), newArgTypes,
375-
funcOp.getFunctionType().getResults());
376-
funcOp.setType(newFuncType);
372+
if (failed(cleanUpOldMemRefs(oldMemRefVals))) {
373+
signalPassFailure();
377374
}
378-
379-
getOperation().dump();
380375
}
381376

382377
std::unique_ptr<OperationPass<func::FuncOp>>

0 commit comments

Comments
 (0)