@@ -100,7 +100,6 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
100100 SmallVector<int64_t , 4 > newShape (originalShape.begin (), originalShape.end ());
101101 assert (newShape.front () % bankingFactor == 0 &&
102102 " memref shape must be divided by the banking factor" );
103- // Now assuming banking the last dimension
104103 newShape.front () /= bankingFactor;
105104 MemRefType newMemRefType =
106105 MemRefType::get (newShape, originalType.getElementType (),
@@ -153,17 +152,6 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
153152 return banks;
154153}
155154
156- Value computeIntraBankingOffset (OpBuilder &builder, Location loc, Value address,
157- uint availableBanks) {
158- Value availBanksVal =
159- builder
160- .create <arith::ConstantOp>(loc, builder.getIndexAttr (availableBanks))
161- .getResult ();
162- Value offset =
163- builder.create <arith::DivUIOp>(loc, address, availBanksVal).getResult ();
164- return offset;
165- }
166-
167155struct BankAffineLoadPattern : public OpRewritePattern <AffineLoadOp> {
168156 BankAffineLoadPattern (MLIRContext *context, uint64_t unrollFactor,
169157 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
@@ -172,7 +160,6 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
172160
173161 LogicalResult matchAndRewrite (AffineLoadOp loadOp,
174162 PatternRewriter &rewriter) const override {
175- llvm::errs () << " load pattern matchAndRewrite\n " ;
176163 Location loc = loadOp.getLoc ();
177164 auto banks = memoryToBanks[loadOp.getMemref ()];
178165 Value loadIndex = loadOp.getIndices ().front ();
@@ -231,7 +218,6 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
231218
232219 LogicalResult matchAndRewrite (AffineStoreOp storeOp,
233220 PatternRewriter &rewriter) const override {
234- llvm::errs () << " store pattern matchAndRewrite\n " ;
235221 Location loc = storeOp.getLoc ();
236222 auto banks = memoryToBanks[storeOp.getMemref ()];
237223 Value storeIndex = storeOp.getIndices ().front ();
@@ -300,6 +286,7 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
300286 auto banks = memoryToBanks[operand];
301287 newReturnOperands.append (banks.begin (), banks.end ());
302288 }
289+
303290 func::FuncOp funcOp = returnOp.getParentOp ();
304291 rewriter.setInsertionPointToEnd (&funcOp.getBlocks ().front ());
305292 auto newReturnOp =
@@ -310,16 +297,44 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
310297 funcOp.getFunctionType ().getInputs (), newReturnType);
311298 funcOp.setType (newFuncType);
312299
313- if (allOrigMemsUsedByReturn) {
300+ if (allOrigMemsUsedByReturn)
314301 rewriter.replaceOp (returnOp, newReturnOp);
315- }
302+
316303 return success ();
317304 }
318305
319306private:
320307 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
321308};
322309
310+ LogicalResult cleanUpOldMemRefs (DenseSet<Value> &oldMemRefVals) {
311+ DenseSet<func::FuncOp> funcsToModify;
312+ for (auto &memrefVal : oldMemRefVals) {
313+ if (!memrefVal.use_empty ())
314+ continue ;
315+ if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
316+ Block *block = blockArg.getOwner ();
317+ block->eraseArgument (blockArg.getArgNumber ());
318+ if (auto funcOp = dyn_cast<func::FuncOp>(block->getParentOp ()))
319+ funcsToModify.insert (funcOp);
320+ } else
321+ memrefVal.getDefiningOp ()->erase ();
322+ }
323+
324+ // Modify the function type accordingly
325+ for (auto funcOp : funcsToModify) {
326+ SmallVector<Type, 4 > newArgTypes;
327+ for (BlockArgument arg : funcOp.getArguments ()) {
328+ newArgTypes.push_back (arg.getType ());
329+ }
330+ FunctionType newFuncType =
331+ FunctionType::get (funcOp.getContext (), newArgTypes,
332+ funcOp.getFunctionType ().getResults ());
333+ funcOp.setType (newFuncType);
334+ }
335+ return success ();
336+ }
337+
323338void ParallelUnroll::runOnOperation () {
324339 if (getOperation ().isExternal ()) {
325340 return ;
@@ -335,7 +350,6 @@ void ParallelUnroll::runOnOperation() {
335350 });
336351
337352 auto *ctx = &getContext ();
338-
339353 RewritePatternSet patterns (ctx);
340354
341355 patterns.add <BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
@@ -350,33 +364,14 @@ void ParallelUnroll::runOnOperation() {
350364 signalPassFailure ();
351365 }
352366
353- DenseSet<Block *> blocksToModify;
354- for (auto &[memrefVal, banks] : memoryToBanks) {
355- if (memrefVal.use_empty ()) {
356- if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
357- blockArg.getOwner ()->eraseArgument (blockArg.getArgNumber ());
358- blocksToModify.insert (blockArg.getOwner ());
359- } else {
360- memrefVal.getDefiningOp ()->erase ();
361- }
362- }
363- }
367+ // Clean up the old memref values
368+ DenseSet<Value> oldMemRefVals;
369+ for (const auto &pair : memoryToBanks)
370+ oldMemRefVals.insert (pair.first );
364371
365- for (auto *block : blocksToModify) {
366- if (!isa<func::FuncOp>(block->getParentOp ()))
367- continue ;
368- func::FuncOp funcOp = cast<func::FuncOp>(block->getParentOp ());
369- SmallVector<Type, 4 > newArgTypes;
370- for (BlockArgument arg : funcOp.getArguments ()) {
371- newArgTypes.push_back (arg.getType ());
372- }
373- FunctionType newFuncType =
374- FunctionType::get (funcOp.getContext (), newArgTypes,
375- funcOp.getFunctionType ().getResults ());
376- funcOp.setType (newFuncType);
372+ if (failed (cleanUpOldMemRefs (oldMemRefVals))) {
373+ signalPassFailure ();
377374 }
378-
379- getOperation ().dump ();
380375}
381376
382377std::unique_ptr<OperationPass<func::FuncOp>>
0 commit comments