2525#include " mlir/IR/IRMapping.h"
2626#include " mlir/IR/Visitors.h"
2727#include " mlir/Support/LLVM.h"
28+ #include " mlir/Transforms/DialectConversion.h"
29+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2830#include " llvm/ADT/DenseMap.h"
2931#include " llvm/ADT/TypeSwitch.h"
3032#include " llvm/Support/CommandLine.h"
@@ -73,7 +75,6 @@ struct ParallelUnroll
7375private:
7476 // map from original memory definition to newly allocated banks
7577 DenseMap<Value, SmallVector<Value>> memoryToBanks;
76- SmallVector<Operation *, 8 > opsToErase;
7778};
7879} // namespace
7980
@@ -163,133 +164,191 @@ Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
163164 return offset;
164165}
165166
166- // / Unrolls a 'affine.parallel' op. Returns success if the loop was unrolled,
167- // / failure otherwise. The default unroll factor is 4.
168- LogicalResult ParallelUnroll::parallelUnrollByFactor (AffineParallelOp parOp,
169- uint64_t unrollFactor) {
170- // 1. identify memrefs in the parallel region,
171- // 2. create memory banks for each of those memories
172- // 2.1 maybe result of alloc/getglobal, etc
173- // 2.2 maybe block arguments
174- //
167+ struct BankAffineLoadPattern : public OpRewritePattern <AffineLoadOp> {
168+ BankAffineLoadPattern (MLIRContext *context, uint64_t unrollFactor,
169+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
170+ : OpRewritePattern<AffineLoadOp>(context), unrollFactor(unrollFactor),
171+ memoryToBanks (memoryToBanks) {}
172+
173+ LogicalResult matchAndRewrite (AffineLoadOp loadOp,
174+ PatternRewriter &rewriter) const override {
175+ llvm::errs () << " load pattern matchAndRewrite\n " ;
176+ Location loc = loadOp.getLoc ();
177+ auto banks = memoryToBanks[loadOp.getMemref ()];
178+ Value loadIndex = loadOp.getIndices ().front ();
179+ rewriter.setInsertionPointToStart (loadOp->getBlock ());
180+ Value bankingFactorValue =
181+ rewriter.create <mlir::arith::ConstantIndexOp>(loc, unrollFactor);
182+ Value bankIndex = rewriter.create <mlir::arith::RemUIOp>(loc, loadIndex,
183+ bankingFactorValue);
184+ Value offset =
185+ computeIntraBankingOffset (rewriter, loc, loadIndex, unrollFactor);
186+
187+ SmallVector<Type> resultTypes = {loadOp.getResult ().getType ()};
188+
189+ SmallVector<int64_t , 4 > caseValues;
190+ for (unsigned i = 0 ; i < unrollFactor; ++i)
191+ caseValues.push_back (i);
192+
193+ rewriter.setInsertionPoint (loadOp);
194+ scf::IndexSwitchOp switchOp = rewriter.create <scf::IndexSwitchOp>(
195+ loc, resultTypes, bankIndex, caseValues,
196+ /* numRegions=*/ unrollFactor);
175197
176- DenseSet<Value> memrefsInPar = collectMemRefs (parOp);
177- Location loc = parOp.getLoc ();
178- OpBuilder builder (parOp);
198+ for (unsigned i = 0 ; i < unrollFactor; ++i) {
199+ Region &caseRegion = switchOp.getCaseRegions ()[i];
200+ rewriter.setInsertionPointToStart (&caseRegion.emplaceBlock ());
201+ Value bankedLoad = rewriter.create <AffineLoadOp>(loc, banks[i], offset);
202+ rewriter.create <scf::YieldOp>(loc, bankedLoad);
203+ }
179204
180- DenseSet<Block *> blocksToModify;
181- for (auto memrefVal : memrefsInPar) {
182- SmallVector<Value> banks = createBanks (memrefVal, unrollFactor);
183- memoryToBanks[memrefVal] = banks;
184-
185- for (auto *user : memrefVal.getUsers ()) {
186- // if user is within parallel region
187- TypeSwitch<Operation *>(user)
188- .Case <affine::AffineLoadOp>([&](affine::AffineLoadOp loadOp) {
189- Value loadIndex = loadOp.getIndices ().front ();
190- builder.setInsertionPointToStart (parOp.getBody ());
191- Value bankingFactorValue =
192- builder.create <mlir::arith::ConstantIndexOp>(loc, unrollFactor);
193- Value bankIndex = builder.create <mlir::arith::RemUIOp>(
194- loc, loadIndex, bankingFactorValue);
195- Value offset = computeIntraBankingOffset (builder, loc, loadIndex,
196- unrollFactor);
197-
198- SmallVector<Type> resultTypes = {loadOp.getResult ().getType ()};
199-
200- SmallVector<int64_t , 4 > caseValues;
201- for (unsigned i = 0 ; i < unrollFactor; ++i)
202- caseValues.push_back (i);
203-
204- builder.setInsertionPoint (user);
205- scf::IndexSwitchOp switchOp = builder.create <scf::IndexSwitchOp>(
206- loc, resultTypes, bankIndex, caseValues,
207- /* numRegions=*/ unrollFactor);
208-
209- for (unsigned i = 0 ; i < unrollFactor; ++i) {
210- Region &caseRegion = switchOp.getCaseRegions ()[i];
211- builder.setInsertionPointToStart (&caseRegion.emplaceBlock ());
212- Value bankedLoad =
213- builder.create <AffineLoadOp>(loc, banks[i], offset);
214- builder.create <scf::YieldOp>(loc, bankedLoad);
215- }
216-
217- Region &defaultRegion = switchOp.getDefaultRegion ();
218- assert (defaultRegion.empty () && " Default region should be empty" );
219- builder.setInsertionPointToStart (&defaultRegion.emplaceBlock ());
220-
221- TypedAttr zeroAttr =
222- cast<TypedAttr>(builder.getZeroAttr (loadOp.getType ()));
223- auto defaultValue =
224- builder.create <arith::ConstantOp>(loc, zeroAttr);
225- builder.create <scf::YieldOp>(loc, defaultValue.getResult ());
226-
227- loadOp.getResult ().replaceAllUsesWith (switchOp.getResult (0 ));
228-
229- user->erase ();
230- })
231- .Case <affine::AffineStoreOp>([&](affine::AffineStoreOp storeOp) {
232- Value loadIndex = storeOp.getIndices ().front ();
233- builder.setInsertionPointToStart (parOp.getBody ());
234- Value bankingFactorValue =
235- builder.create <mlir::arith::ConstantIndexOp>(loc, unrollFactor);
236- Value bankIndex = builder.create <mlir::arith::RemUIOp>(
237- loc, loadIndex, bankingFactorValue);
238- Value offset = computeIntraBankingOffset (builder, loc, loadIndex,
239- unrollFactor);
240-
241- SmallVector<Type> resultTypes = {};
242-
243- SmallVector<int64_t , 4 > caseValues;
244- for (unsigned i = 0 ; i < unrollFactor; ++i)
245- caseValues.push_back (i);
246-
247- builder.setInsertionPoint (user);
248- scf::IndexSwitchOp switchOp = builder.create <scf::IndexSwitchOp>(
249- loc, resultTypes, bankIndex, caseValues,
250- /* numRegions=*/ unrollFactor);
251-
252- for (unsigned i = 0 ; i < unrollFactor; ++i) {
253- Region &caseRegion = switchOp.getCaseRegions ()[i];
254- builder.setInsertionPointToStart (&caseRegion.emplaceBlock ());
255- builder.create <AffineStoreOp>(loc, storeOp.getValueToStore (),
256- banks[i], offset);
257- builder.create <scf::YieldOp>(loc);
258- }
259-
260- Region &defaultRegion = switchOp.getDefaultRegion ();
261- assert (defaultRegion.empty () && " Default region should be empty" );
262- builder.setInsertionPointToStart (&defaultRegion.emplaceBlock ());
263-
264- builder.create <scf::YieldOp>(loc);
265-
266- user->erase ();
267- })
268- .Default ([](Operation *op) {
269- op->emitWarning (" Unhandled operation type" );
270- op->dump ();
271- });
205+ Region &defaultRegion = switchOp.getDefaultRegion ();
206+ assert (defaultRegion.empty () && " Default region should be empty" );
207+ rewriter.setInsertionPointToStart (&defaultRegion.emplaceBlock ());
208+
209+ TypedAttr zeroAttr =
210+ cast<TypedAttr>(rewriter.getZeroAttr (loadOp.getType ()));
211+ auto defaultValue = rewriter.create <arith::ConstantOp>(loc, zeroAttr);
212+ rewriter.create <scf::YieldOp>(loc, defaultValue.getResult ());
213+
214+ loadOp.getResult ().replaceAllUsesWith (switchOp.getResult (0 ));
215+
216+ rewriter.eraseOp (loadOp);
217+ return success ();
218+ }
219+
220+ private:
221+ uint64_t unrollFactor;
222+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
223+ };
224+
225+ struct BankAffineStorePattern : public OpRewritePattern <AffineStoreOp> {
226+ BankAffineStorePattern (MLIRContext *context, uint64_t unrollFactor,
227+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
228+ : OpRewritePattern<AffineStoreOp>(context), unrollFactor(unrollFactor),
229+ memoryToBanks (memoryToBanks) {}
230+
231+ LogicalResult matchAndRewrite (AffineStoreOp storeOp,
232+ PatternRewriter &rewriter) const override {
233+ llvm::errs () << " store pattern matchAndRewrite\n " ;
234+ Location loc = storeOp.getLoc ();
235+ auto banks = memoryToBanks[storeOp.getMemref ()];
236+ Value loadIndex = storeOp.getIndices ().front ();
237+ rewriter.setInsertionPointToStart (storeOp->getBlock ());
238+ Value bankingFactorValue =
239+ rewriter.create <mlir::arith::ConstantIndexOp>(loc, unrollFactor);
240+ Value bankIndex = rewriter.create <mlir::arith::RemUIOp>(loc, loadIndex,
241+ bankingFactorValue);
242+ Value offset =
243+ computeIntraBankingOffset (rewriter, loc, loadIndex, unrollFactor);
244+
245+ SmallVector<Type> resultTypes = {};
246+
247+ SmallVector<int64_t , 4 > caseValues;
248+ for (unsigned i = 0 ; i < unrollFactor; ++i)
249+ caseValues.push_back (i);
250+
251+ rewriter.setInsertionPoint (storeOp);
252+ scf::IndexSwitchOp switchOp = rewriter.create <scf::IndexSwitchOp>(
253+ loc, resultTypes, bankIndex, caseValues,
254+ /* numRegions=*/ unrollFactor);
255+
256+ for (unsigned i = 0 ; i < unrollFactor; ++i) {
257+ Region &caseRegion = switchOp.getCaseRegions ()[i];
258+ rewriter.setInsertionPointToStart (&caseRegion.emplaceBlock ());
259+ rewriter.create <AffineStoreOp>(loc, storeOp.getValueToStore (), banks[i],
260+ offset);
261+ rewriter.create <scf::YieldOp>(loc);
272262 }
273263
274- for (auto *user : memrefVal.getUsers ()) {
275- if (auto returnOp = dyn_cast<func::ReturnOp>(user)) {
276- OpBuilder builder (returnOp);
277- func::FuncOp funcOp = returnOp.getParentOp ();
278- builder.setInsertionPointToEnd (&funcOp.getBlocks ().front ());
279- auto newReturnOp =
280- builder.create <func::ReturnOp>(loc, ValueRange (banks));
281- TypeRange newReturnType = TypeRange (banks);
282- FunctionType newFuncType = FunctionType::get (
283- funcOp.getContext (), funcOp.getFunctionType ().getInputs (),
284- newReturnType);
285- funcOp.setType (newFuncType);
286- returnOp->replaceAllUsesWith (newReturnOp);
287- opsToErase.push_back (returnOp);
264+ Region &defaultRegion = switchOp.getDefaultRegion ();
265+ assert (defaultRegion.empty () && " Default region should be empty" );
266+ rewriter.setInsertionPointToStart (&defaultRegion.emplaceBlock ());
267+
268+ rewriter.create <scf::YieldOp>(loc);
269+
270+ rewriter.eraseOp (storeOp);
271+ return success ();
272+ }
273+
274+ private:
275+ uint64_t unrollFactor;
276+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
277+ };
278+
279+ struct BankReturnPattern : public OpRewritePattern <func::ReturnOp> {
280+ BankReturnPattern (MLIRContext *context,
281+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
282+ : OpRewritePattern<func::ReturnOp>(context),
283+ memoryToBanks (memoryToBanks) {}
284+
285+ LogicalResult matchAndRewrite (func::ReturnOp returnOp,
286+ PatternRewriter &rewriter) const override {
287+ Location loc = returnOp.getLoc ();
288+ SmallVector<Value, 4 > newReturnOperands;
289+ bool allOrigMemsUsedByReturn = true ;
290+ for (auto operand : returnOp.getOperands ()) {
291+ if (!memoryToBanks.contains (operand)) {
292+ newReturnOperands.push_back (operand);
293+ continue ;
288294 }
295+ if (operand.hasOneUse ())
296+ allOrigMemsUsedByReturn = false ;
297+ auto banks = memoryToBanks[operand];
298+ newReturnOperands.append (banks.begin (), banks.end ());
299+ }
300+ func::FuncOp funcOp = returnOp.getParentOp ();
301+ rewriter.setInsertionPointToEnd (&funcOp.getBlocks ().front ());
302+ auto newReturnOp =
303+ rewriter.create <func::ReturnOp>(loc, ValueRange (newReturnOperands));
304+ TypeRange newReturnType = TypeRange (newReturnOperands);
305+ FunctionType newFuncType =
306+ FunctionType::get (funcOp.getContext (),
307+ funcOp.getFunctionType ().getInputs (), newReturnType);
308+ funcOp.setType (newFuncType);
309+
310+ if (allOrigMemsUsedByReturn) {
311+ rewriter.replaceOp (returnOp, newReturnOp);
289312 }
313+ return success ();
314+ }
290315
291- // TODO: if use is empty, we should delete the original block args; and
292- // reset function type
316+ private:
317+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
318+ };
319+
320+ void ParallelUnroll::runOnOperation () {
321+ if (getOperation ().isExternal ()) {
322+ return ;
323+ }
324+
325+ getOperation ().walk ([&](AffineParallelOp parOp) {
326+ DenseSet<Value> memrefsInPar = collectMemRefs (parOp);
327+
328+ for (auto memrefVal : memrefsInPar) {
329+ SmallVector<Value> banks = createBanks (memrefVal, unrollFactor);
330+ memoryToBanks[memrefVal] = banks;
331+ }
332+ });
333+
334+ auto *ctx = &getContext ();
335+
336+ RewritePatternSet patterns (ctx);
337+
338+ patterns.add <BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
339+ patterns.add <BankAffineStorePattern>(ctx, unrollFactor, memoryToBanks);
340+ patterns.add <BankReturnPattern>(ctx, memoryToBanks);
341+
342+ GreedyRewriteConfig config;
343+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
344+
345+ if (failed (applyPatternsAndFoldGreedily (getOperation (), std::move (patterns),
346+ config))) {
347+ signalPassFailure ();
348+ }
349+
350+ DenseSet<Block *> blocksToModify;
351+ for (auto &[memrefVal, banks] : memoryToBanks) {
293352 if (memrefVal.use_empty ()) {
294353 if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
295354 blockArg.getOwner ()->eraseArgument (blockArg.getArgNumber ());
@@ -314,32 +373,7 @@ LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
314373 funcOp.setType (newFuncType);
315374 }
316375
317- // / - `isDefinedOutsideRegion` returns true if the given value is invariant
318- // / with
319- // / respect to the given region. A common implementation might be:
320- // / `value.getParentRegion()->isProperAncestor(region)`.
321-
322- if (unrollFactor == 1 ) {
323- // TODO: how to address "expected pattern to replace the root operation" if
324- // just simply return success
325- return success ();
326- }
327-
328- return success ();
329- }
330-
331- void ParallelUnroll::runOnOperation () {
332- if (getOperation ().isExternal ()) {
333- return ;
334- }
335-
336- getOperation ().walk ([&](AffineParallelOp parOp) {
337- (void )parallelUnrollByFactor (parOp, unrollFactor);
338- return WalkResult::advance ();
339- });
340- for (auto *op : opsToErase) {
341- op->erase ();
342- }
376+ getOperation ().dump ();
343377}
344378
345379std::unique_ptr<OperationPass<func::FuncOp>>
0 commit comments