1- // ===- ParallelUnroll .cpp - Code to perform parallel loop unrolling
1+ // ===- ParallelBanking .cpp - Code to perform memory bnaking in parallel loops
22// --------------------===//
33//
44// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
77//
88// ===----------------------------------------------------------------------===//
99//
10- // This file implements parallel loop unrolling .
10+ // This file implements parallel loop memory banking .
1111//
1212// ===----------------------------------------------------------------------===//
1313
2222#include " mlir/IR/AffineExpr.h"
2323#include " mlir/IR/AffineMap.h"
2424#include " mlir/IR/Builders.h"
25- #include " mlir/IR/IRMapping.h"
2625#include " mlir/IR/Visitors.h"
2726#include " mlir/Support/LLVM.h"
2827#include " mlir/Transforms/DialectConversion.h"
2928#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
3029#include " llvm/ADT/DenseMap.h"
3130#include " llvm/ADT/TypeSwitch.h"
32- #include " llvm/Support/CommandLine.h"
33- #include " llvm/Support/Debug.h"
34- #include " llvm/Support/raw_ostream.h"
3531#include < cassert>
3632
3733namespace mlir {
3834namespace affine {
39- #define GEN_PASS_DEF_AFFINEPARALLELUNROLL
35+ #define GEN_PASS_DEF_AFFINEPARALLELBANKING
4036#include " mlir/Dialect/Affine/Passes.h.inc"
4137} // namespace affine
4238} // namespace mlir
4339
44- #define DEBUG_TYPE " affine-parallel-unroll "
40+ #define DEBUG_TYPE " affine-parallel-banking "
4541
4642using namespace mlir ;
4743using namespace mlir ::affine;
4844
4945namespace {
5046
51- // / Unroll an `affine.parallel` operation by the `unrollFactor` specified in the
52- // / attribute. Evenly splitting `memref`s that are present in the `parallel`
53- // / region into smaller banks.
54- struct ParallelUnroll
55- : public affine::impl::AffineParallelUnrollBase<ParallelUnroll> {
56- const std::function<unsigned (AffineParallelOp)> getUnrollFactor;
57- ParallelUnroll () : getUnrollFactor(nullptr ) {}
58- ParallelUnroll (const ParallelUnroll &other) = default ;
59- explicit ParallelUnroll (std::optional<unsigned > unrollFactor = std::nullopt ,
60- const std::function<unsigned (AffineParallelOp)>
61- &getUnrollFactor = nullptr)
62- : getUnrollFactor(getUnrollFactor) {
63- if (unrollFactor)
64- this ->unrollFactor = *unrollFactor;
47+ // / Partition memories used in `affine.parallel` operation by the
48+ // / `bankingFactor` throughout the program.
49+ struct ParallelBanking
50+ : public affine::impl::AffineParallelBankingBase<ParallelBanking> {
51+ const std::function<unsigned (AffineParallelOp)> getBankingFactor;
52+ ParallelBanking () : getBankingFactor(nullptr ) {}
53+ ParallelBanking (const ParallelBanking &other) = default ;
54+ explicit ParallelBanking (std::optional<unsigned > bankingFactor = std::nullopt ,
55+ const std::function<unsigned (AffineParallelOp)>
56+ &getBankingFactor = nullptr)
57+ : getBankingFactor(getBankingFactor) {
58+ if (bankingFactor)
59+ this ->bankingFactor = *bankingFactor;
6560 }
6661
6762 void getDependentDialects (DialectRegistry ®istry) const override {
6863 registry.insert <mlir::scf::SCFDialect>();
6964 }
7065
7166 void runOnOperation () override ;
72- LogicalResult parallelUnrollByFactor (AffineParallelOp parOp,
73- uint64_t unrollFactor );
67+ LogicalResult parallelBankingByFactor (AffineParallelOp parOp,
68+ uint64_t bankingFactor );
7469
7570private:
7671 // map from original memory definition to newly allocated banks
@@ -108,22 +103,23 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
108103 return newMemRefType;
109104}
110105
111- SmallVector<Value> createBanks (Value originalMem, uint64_t unrollFactor ) {
106+ SmallVector<Value> createBanks (Value originalMem, uint64_t bankingFactor ) {
112107 MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType ());
113108 MemRefType newMemRefType =
114- computeBankedMemRefType (originalMemRefType, unrollFactor );
109+ computeBankedMemRefType (originalMemRefType, bankingFactor );
115110 SmallVector<Value, 4 > banks;
116111 if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
117112 Block *block = blockArgMem.getOwner ();
118113 unsigned blockArgNum = blockArgMem.getArgNumber ();
119114
120115 SmallVector<Type> banksType;
121- for (unsigned i = 0 ; i < unrollFactor ; ++i) {
116+ for (unsigned i = 0 ; i < bankingFactor ; ++i) {
122117 block->insertArgument (blockArgNum + 1 + i, newMemRefType,
123118 blockArgMem.getLoc ());
124119 }
125120
126- auto blockArgs = block->getArguments ().slice (blockArgNum + 1 , unrollFactor);
121+ auto blockArgs =
122+ block->getArguments ().slice (blockArgNum + 1 , bankingFactor);
127123 banks.append (blockArgs.begin (), blockArgs.end ());
128124 } else {
129125 Operation *originalDef = originalMem.getDefiningOp ();
@@ -132,14 +128,14 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
132128 builder.setInsertionPointAfter (originalDef);
133129 TypeSwitch<Operation *>(originalDef)
134130 .Case <memref::AllocOp>([&](memref::AllocOp allocOp) {
135- for (uint bankCnt = 0 ; bankCnt < unrollFactor ; bankCnt++) {
131+ for (uint bankCnt = 0 ; bankCnt < bankingFactor ; bankCnt++) {
136132 auto bankAllocOp =
137133 builder.create <memref::AllocOp>(loc, newMemRefType);
138134 banks.push_back (bankAllocOp);
139135 }
140136 })
141137 .Case <memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
142- for (uint bankCnt = 0 ; bankCnt < unrollFactor ; bankCnt++) {
138+ for (uint bankCnt = 0 ; bankCnt < bankingFactor ; bankCnt++) {
143139 auto bankAllocaOp =
144140 builder.create <memref::AllocaOp>(loc, newMemRefType);
145141 banks.push_back (bankAllocaOp);
@@ -153,9 +149,9 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
153149}
154150
155151struct BankAffineLoadPattern : public OpRewritePattern <AffineLoadOp> {
156- BankAffineLoadPattern (MLIRContext *context, uint64_t unrollFactor ,
152+ BankAffineLoadPattern (MLIRContext *context, uint64_t bankingFactor ,
157153 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
158- : OpRewritePattern<AffineLoadOp>(context), unrollFactor(unrollFactor ),
154+ : OpRewritePattern<AffineLoadOp>(context), bankingFactor(bankingFactor ),
159155 memoryToBanks (memoryToBanks) {}
160156
161157 LogicalResult matchAndRewrite (AffineLoadOp loadOp,
@@ -164,9 +160,9 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
164160 auto banks = memoryToBanks[loadOp.getMemref ()];
165161 Value loadIndex = loadOp.getIndices ().front ();
166162 auto modMap =
167- AffineMap::get (1 , 0 , {rewriter.getAffineDimExpr (0 ) % unrollFactor });
163+ AffineMap::get (1 , 0 , {rewriter.getAffineDimExpr (0 ) % bankingFactor });
168164 auto divMap = AffineMap::get (
169- 1 , 0 , {rewriter.getAffineDimExpr (0 ).floorDiv (unrollFactor )});
165+ 1 , 0 , {rewriter.getAffineDimExpr (0 ).floorDiv (bankingFactor )});
170166
171167 Value bankIndex = rewriter.create <AffineApplyOp>(
172168 loc, modMap, loadIndex); // assuming one-dim
@@ -175,15 +171,15 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
175171 SmallVector<Type> resultTypes = {loadOp.getResult ().getType ()};
176172
177173 SmallVector<int64_t , 4 > caseValues;
178- for (unsigned i = 0 ; i < unrollFactor ; ++i)
174+ for (unsigned i = 0 ; i < bankingFactor ; ++i)
179175 caseValues.push_back (i);
180176
181177 rewriter.setInsertionPoint (loadOp);
182178 scf::IndexSwitchOp switchOp = rewriter.create <scf::IndexSwitchOp>(
183179 loc, resultTypes, bankIndex, caseValues,
184- /* numRegions=*/ unrollFactor );
180+ /* numRegions=*/ bankingFactor );
185181
186- for (unsigned i = 0 ; i < unrollFactor ; ++i) {
182+ for (unsigned i = 0 ; i < bankingFactor ; ++i) {
187183 Region &caseRegion = switchOp.getCaseRegions ()[i];
188184 rewriter.setInsertionPointToStart (&caseRegion.emplaceBlock ());
189185 Value bankedLoad = rewriter.create <AffineLoadOp>(loc, banks[i], offset);
@@ -206,14 +202,14 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
206202 }
207203
208204private:
209- uint64_t unrollFactor ;
205+ uint64_t bankingFactor ;
210206 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
211207};
212208
213209struct BankAffineStorePattern : public OpRewritePattern <AffineStoreOp> {
214- BankAffineStorePattern (MLIRContext *context, uint64_t unrollFactor ,
210+ BankAffineStorePattern (MLIRContext *context, uint64_t bankingFactor ,
215211 DenseMap<Value, SmallVector<Value>> &memoryToBanks)
216- : OpRewritePattern<AffineStoreOp>(context), unrollFactor(unrollFactor ),
212+ : OpRewritePattern<AffineStoreOp>(context), bankingFactor(bankingFactor ),
217213 memoryToBanks (memoryToBanks) {}
218214
219215 LogicalResult matchAndRewrite (AffineStoreOp storeOp,
@@ -223,9 +219,9 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
223219 Value storeIndex = storeOp.getIndices ().front ();
224220
225221 auto modMap =
226- AffineMap::get (1 , 0 , {rewriter.getAffineDimExpr (0 ) % unrollFactor });
222+ AffineMap::get (1 , 0 , {rewriter.getAffineDimExpr (0 ) % bankingFactor });
227223 auto divMap = AffineMap::get (
228- 1 , 0 , {rewriter.getAffineDimExpr (0 ).floorDiv (unrollFactor )});
224+ 1 , 0 , {rewriter.getAffineDimExpr (0 ).floorDiv (bankingFactor )});
229225
230226 Value bankIndex = rewriter.create <AffineApplyOp>(
231227 loc, modMap, storeIndex); // assuming one-dim
@@ -234,15 +230,15 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
234230 SmallVector<Type> resultTypes = {};
235231
236232 SmallVector<int64_t , 4 > caseValues;
237- for (unsigned i = 0 ; i < unrollFactor ; ++i)
233+ for (unsigned i = 0 ; i < bankingFactor ; ++i)
238234 caseValues.push_back (i);
239235
240236 rewriter.setInsertionPoint (storeOp);
241237 scf::IndexSwitchOp switchOp = rewriter.create <scf::IndexSwitchOp>(
242238 loc, resultTypes, bankIndex, caseValues,
243- /* numRegions=*/ unrollFactor );
239+ /* numRegions=*/ bankingFactor );
244240
245- for (unsigned i = 0 ; i < unrollFactor ; ++i) {
241+ for (unsigned i = 0 ; i < bankingFactor ; ++i) {
246242 Region &caseRegion = switchOp.getCaseRegions ()[i];
247243 rewriter.setInsertionPointToStart (&caseRegion.emplaceBlock ());
248244 rewriter.create <AffineStoreOp>(loc, storeOp.getValueToStore (), banks[i],
@@ -261,7 +257,7 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
261257 }
262258
263259private:
264- uint64_t unrollFactor ;
260+ uint64_t bankingFactor ;
265261 DenseMap<Value, SmallVector<Value>> &memoryToBanks;
266262};
267263
@@ -335,7 +331,7 @@ LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals) {
335331 return success ();
336332}
337333
338- void ParallelUnroll ::runOnOperation () {
334+ void ParallelBanking ::runOnOperation () {
339335 if (getOperation ().isExternal ()) {
340336 return ;
341337 }
@@ -344,16 +340,16 @@ void ParallelUnroll::runOnOperation() {
344340 DenseSet<Value> memrefsInPar = collectMemRefs (parOp);
345341
346342 for (auto memrefVal : memrefsInPar) {
347- SmallVector<Value> banks = createBanks (memrefVal, unrollFactor );
343+ SmallVector<Value> banks = createBanks (memrefVal, bankingFactor );
348344 memoryToBanks[memrefVal] = banks;
349345 }
350346 });
351347
352348 auto *ctx = &getContext ();
353349 RewritePatternSet patterns (ctx);
354350
355- patterns.add <BankAffineLoadPattern>(ctx, unrollFactor , memoryToBanks);
356- patterns.add <BankAffineStorePattern>(ctx, unrollFactor , memoryToBanks);
351+ patterns.add <BankAffineLoadPattern>(ctx, bankingFactor , memoryToBanks);
352+ patterns.add <BankAffineStorePattern>(ctx, bankingFactor , memoryToBanks);
357353 patterns.add <BankReturnPattern>(ctx, memoryToBanks);
358354
359355 GreedyRewriteConfig config;
@@ -375,10 +371,11 @@ void ParallelUnroll::runOnOperation() {
375371}
376372
377373std::unique_ptr<OperationPass<func::FuncOp>>
378- mlir::affine::createParallelUnrollPass (
379- int unrollFactor,
380- const std::function<unsigned (AffineParallelOp)> &getUnrollFactor) {
381- return std::make_unique<ParallelUnroll>(
382- unrollFactor == -1 ? std::nullopt : std::optional<unsigned >(unrollFactor),
383- getUnrollFactor);
374+ mlir::affine::createParallelBankingPass (
375+ int bankingFactor,
376+ const std::function<unsigned (AffineParallelOp)> &getBankingFactor) {
377+ return std::make_unique<ParallelBanking>(
378+ bankingFactor == -1 ? std::nullopt
379+ : std::optional<unsigned >(bankingFactor),
380+ getBankingFactor);
384381}
0 commit comments