Skip to content

Commit b65b2df

Browse files
committed
use affine map to load
1 parent 9bcc13c commit b65b2df

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,14 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
176176
Location loc = loadOp.getLoc();
177177
auto banks = memoryToBanks[loadOp.getMemref()];
178178
Value loadIndex = loadOp.getIndices().front();
179-
rewriter.setInsertionPointToStart(loadOp->getBlock());
180-
Value bankingFactorValue =
181-
rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
182-
Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
183-
bankingFactorValue);
184-
Value offset =
185-
computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
179+
auto modMap =
180+
AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
181+
auto divMap = AffineMap::get(
182+
1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
183+
184+
Value bankIndex = rewriter.create<AffineApplyOp>(
185+
loc, modMap, loadIndex); // assuming one-dim
186+
Value offset = rewriter.create<AffineApplyOp>(loc, divMap, loadIndex);
186187

187188
SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
188189

@@ -233,14 +234,16 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
233234
llvm::errs() << "store pattern matchAndRewrite\n";
234235
Location loc = storeOp.getLoc();
235236
auto banks = memoryToBanks[storeOp.getMemref()];
236-
Value loadIndex = storeOp.getIndices().front();
237-
rewriter.setInsertionPointToStart(storeOp->getBlock());
238-
Value bankingFactorValue =
239-
rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
240-
Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
241-
bankingFactorValue);
242-
Value offset =
243-
computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
237+
Value storeIndex = storeOp.getIndices().front();
238+
239+
auto modMap =
240+
AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
241+
auto divMap = AffineMap::get(
242+
1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
243+
244+
Value bankIndex = rewriter.create<AffineApplyOp>(
245+
loc, modMap, storeIndex); // assuming one-dim
246+
Value offset = rewriter.create<AffineApplyOp>(loc, divMap, storeIndex);
244247

245248
SmallVector<Type> resultTypes = {};
246249

0 commit comments

Comments
 (0)