Skip to content

Commit 9bcc13c

Browse files
committed
use rewriter pattern partially working
1 parent 045d377 commit 9bcc13c

File tree

1 file changed

+181
-147
lines changed

1 file changed

+181
-147
lines changed

mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp

Lines changed: 181 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "mlir/IR/IRMapping.h"
2626
#include "mlir/IR/Visitors.h"
2727
#include "mlir/Support/LLVM.h"
28+
#include "mlir/Transforms/DialectConversion.h"
29+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2830
#include "llvm/ADT/DenseMap.h"
2931
#include "llvm/ADT/TypeSwitch.h"
3032
#include "llvm/Support/CommandLine.h"
@@ -73,7 +75,6 @@ struct ParallelUnroll
7375
private:
7476
// map from original memory definition to newly allocated banks
7577
DenseMap<Value, SmallVector<Value>> memoryToBanks;
76-
SmallVector<Operation *, 8> opsToErase;
7778
};
7879
} // namespace
7980

@@ -163,133 +164,191 @@ Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
163164
return offset;
164165
}
165166

166-
/// Unrolls a 'affine.parallel' op. Returns success if the loop was unrolled,
167-
/// failure otherwise. The default unroll factor is 4.
168-
LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
169-
uint64_t unrollFactor) {
170-
// 1. identify memrefs in the parallel region,
171-
// 2. create memory banks for each of those memories
172-
// 2.1 maybe result of alloc/getglobal, etc
173-
// 2.2 maybe block arguments
174-
//
167+
/// Rewrites an `affine.load` whose memref has an entry in `memoryToBanks` into
/// an `scf.index_switch` over the banks recorded for that memref.
///
/// The switch selector is `index % unrollFactor`; the address used inside the
/// selected bank is produced by `computeIntraBankingOffset`. Every case region
/// loads from one bank and yields the loaded value, and the default region
/// yields a zero constant of the loaded type. The original load is replaced by
/// the result of the switch.
///
/// NOTE(review): only the first entry of `getIndices()` is consumed, so this
/// presumably targets rank-1 memrefs accessed through an identity affine map —
/// confirm before using it on other access shapes.
struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
  /// \param context        context the pattern is registered in.
  /// \param unrollFactor   number of banks each memory is split into.
  /// \param memoryToBanks  map from an original memref to its banks; held by
  ///                       reference, so it must outlive the pattern.
  BankAffineLoadPattern(MLIRContext *context, uint64_t unrollFactor,
                        DenseMap<Value, SmallVector<Value>> &memoryToBanks)
      : OpRewritePattern<AffineLoadOp>(context), unrollFactor(unrollFactor),
        memoryToBanks(memoryToBanks) {}

  LogicalResult matchAndRewrite(AffineLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    // NOTE(review): unconditional trace kept from the original; consider
    // guarding it with LLVM_DEBUG.
    llvm::errs() << "load pattern matchAndRewrite\n";

    // Every check below runs before the first IR mutation, so a failed match
    // leaves the IR untouched.
    //
    // `find` is used instead of `operator[]`: the latter would insert an empty
    // bank list for a memref that was never banked (for instance a load from a
    // bank created by this very pattern) and `banks[i]` would then read out of
    // bounds.
    auto bankEntry = memoryToBanks.find(loadOp.getMemref());
    if (bankEntry == memoryToBanks.end())
      return rewriter.notifyMatchFailure(loadOp, "memref is not banked");
    ArrayRef<Value> banks = bankEntry->second;
    if (banks.size() < unrollFactor)
      return rewriter.notifyMatchFailure(loadOp,
                                         "fewer banks than the unroll factor");
    if (loadOp.getIndices().empty())
      return rewriter.notifyMatchFailure(loadOp, "load has no index operand");

    // The zero attribute may be null for types without a zero value.
    auto zeroAttr =
        dyn_cast_or_null<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
    if (!zeroAttr)
      return rewriter.notifyMatchFailure(loadOp,
                                         "no zero value for the loaded type");

    Location loc = loadOp.getLoc();
    Value loadIndex = loadOp.getIndices().front();

    // Emit the address arithmetic immediately before the load. Inserting it at
    // the start of the block would place it ahead of `loadIndex` whenever the
    // index is computed inside that block, breaking dominance.
    rewriter.setInsertionPoint(loadOp);
    Value bankingFactorValue =
        rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
    Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
                                                            bankingFactorValue);
    Value offset =
        computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);

    SmallVector<Type> resultTypes = {loadOp.getResult().getType()};

    SmallVector<int64_t, 4> caseValues;
    caseValues.reserve(unrollFactor);
    for (uint64_t i = 0; i < unrollFactor; ++i)
      caseValues.push_back(i);

    // Re-anchor at the load in case the helper moved the insertion point.
    rewriter.setInsertionPoint(loadOp);
    scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
        loc, resultTypes, bankIndex, caseValues,
        /*numRegions=*/unrollFactor);

    // One case per bank: load from that bank and yield the value. Blocks are
    // created through the rewriter so the driver is notified.
    for (uint64_t i = 0; i < unrollFactor; ++i) {
      Region &caseRegion = switchOp.getCaseRegions()[i];
      rewriter.createBlock(&caseRegion);
      Value bankedLoad = rewriter.create<AffineLoadOp>(loc, banks[i], offset);
      rewriter.create<scf::YieldOp>(loc, bankedLoad);
    }

    // The default region must still yield a value of the result type.
    Region &defaultRegion = switchOp.getDefaultRegion();
    assert(defaultRegion.empty() && "Default region should be empty");
    rewriter.createBlock(&defaultRegion);
    auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());

    // Replace through the rewriter rather than calling replaceAllUsesWith on
    // the result directly, so users are re-enqueued by the driver.
    rewriter.replaceOp(loadOp, switchOp.getResults());
    return success();
  }

private:
  uint64_t unrollFactor;
  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
224+
225+
/// Rewrites an `affine.store` whose memref has an entry in `memoryToBanks`
/// into a result-less `scf.index_switch` over the banks recorded for that
/// memref.
///
/// The switch selector is `index % unrollFactor`; the address used inside the
/// selected bank is produced by `computeIntraBankingOffset`. Every case region
/// stores the original value into one bank; the default region only yields.
/// The original store is erased.
///
/// NOTE(review): only the first entry of `getIndices()` is consumed, so this
/// presumably targets rank-1 memrefs accessed through an identity affine map —
/// confirm before using it on other access shapes.
struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
  /// \param context        context the pattern is registered in.
  /// \param unrollFactor   number of banks each memory is split into.
  /// \param memoryToBanks  map from an original memref to its banks; held by
  ///                       reference, so it must outlive the pattern.
  BankAffineStorePattern(MLIRContext *context, uint64_t unrollFactor,
                         DenseMap<Value, SmallVector<Value>> &memoryToBanks)
      : OpRewritePattern<AffineStoreOp>(context), unrollFactor(unrollFactor),
        memoryToBanks(memoryToBanks) {}

  LogicalResult matchAndRewrite(AffineStoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    // NOTE(review): unconditional trace kept from the original; consider
    // guarding it with LLVM_DEBUG.
    llvm::errs() << "store pattern matchAndRewrite\n";

    // Every check below runs before the first IR mutation, so a failed match
    // leaves the IR untouched.
    //
    // `find` is used instead of `operator[]`: the latter would insert an empty
    // bank list for a memref that was never banked (for instance a store into
    // a bank created by this very pattern) and `banks[i]` would then read out
    // of bounds.
    auto bankEntry = memoryToBanks.find(storeOp.getMemref());
    if (bankEntry == memoryToBanks.end())
      return rewriter.notifyMatchFailure(storeOp, "memref is not banked");
    ArrayRef<Value> banks = bankEntry->second;
    if (banks.size() < unrollFactor)
      return rewriter.notifyMatchFailure(storeOp,
                                         "fewer banks than the unroll factor");
    if (storeOp.getIndices().empty())
      return rewriter.notifyMatchFailure(storeOp, "store has no index operand");

    Location loc = storeOp.getLoc();
    Value storeIndex = storeOp.getIndices().front();

    // Emit the address arithmetic immediately before the store. Inserting it
    // at the start of the block would place it ahead of `storeIndex` whenever
    // the index is computed inside that block, breaking dominance.
    rewriter.setInsertionPoint(storeOp);
    Value bankingFactorValue =
        rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
    Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, storeIndex,
                                                            bankingFactorValue);
    Value offset =
        computeIntraBankingOffset(rewriter, loc, storeIndex, unrollFactor);

    // A store produces no value, so the switch has no results.
    SmallVector<Type> resultTypes = {};

    SmallVector<int64_t, 4> caseValues;
    caseValues.reserve(unrollFactor);
    for (uint64_t i = 0; i < unrollFactor; ++i)
      caseValues.push_back(i);

    // Re-anchor at the store in case the helper moved the insertion point.
    rewriter.setInsertionPoint(storeOp);
    scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
        loc, resultTypes, bankIndex, caseValues,
        /*numRegions=*/unrollFactor);

    // One case per bank: store the value into that bank. Blocks are created
    // through the rewriter so the driver is notified.
    for (uint64_t i = 0; i < unrollFactor; ++i) {
      Region &caseRegion = switchOp.getCaseRegions()[i];
      rewriter.createBlock(&caseRegion);
      rewriter.create<AffineStoreOp>(loc, storeOp.getValueToStore(), banks[i],
                                     offset);
      rewriter.create<scf::YieldOp>(loc);
    }

    // The default region only needs its terminator.
    Region &defaultRegion = switchOp.getDefaultRegion();
    assert(defaultRegion.empty() && "Default region should be empty");
    rewriter.createBlock(&defaultRegion);
    rewriter.create<scf::YieldOp>(loc);

    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  uint64_t unrollFactor;
  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
278+
279+
/// Rewrites a `func.return` that returns a memref found in `memoryToBanks` so
/// that it returns the banks of that memref instead, and updates the result
/// types of the enclosing function to match.
///
/// Operands without an entry in `memoryToBanks` are forwarded unchanged, in
/// their original position; each banked operand is expanded in place into its
/// list of banks.
struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
  /// \param context        context the pattern is registered in.
  /// \param memoryToBanks  map from an original memref to its banks; held by
  ///                       reference, so it must outlive the pattern.
  BankReturnPattern(MLIRContext *context,
                    DenseMap<Value, SmallVector<Value>> &memoryToBanks)
      : OpRewritePattern<func::ReturnOp>(context),
        memoryToBanks(memoryToBanks) {}

  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> newReturnOperands;
    bool returnsBankedMemory = false;
    for (Value operand : returnOp.getOperands()) {
      auto bankEntry = memoryToBanks.find(operand);
      if (bankEntry == memoryToBanks.end()) {
        newReturnOperands.push_back(operand);
        continue;
      }
      returnsBankedMemory = true;
      newReturnOperands.append(bankEntry->second.begin(),
                               bankEntry->second.end());
    }

    // Nothing to expand: report failure instead of rebuilding an identical
    // return, so the driver does not see a spurious change.
    if (!returnsBankedMemory)
      return rewriter.notifyMatchFailure(returnOp,
                                         "no banked memref is returned");

    // Keep the function signature in sync with the new return operands.
    // NOTE(review): this mutates the parent function directly rather than
    // through the rewriter, as the original did.
    func::FuncOp funcOp = returnOp.getParentOp();
    TypeRange newReturnType = TypeRange(newReturnOperands);
    FunctionType newFuncType =
        FunctionType::get(funcOp.getContext(),
                          funcOp.getFunctionType().getInputs(), newReturnType);
    funcOp.setType(newFuncType);

    // Always swap the old return for the new one, at the old return's own
    // position. Leaving the old terminator in place next to a freshly created
    // one would give the block two terminators, and appending to the entry
    // block would be wrong for a return that lives in a different block.
    rewriter.setInsertionPoint(returnOp);
    rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
                                                ValueRange(newReturnOperands));
    return success();
  }

private:
  DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
319+
320+
void ParallelUnroll::runOnOperation() {
321+
if (getOperation().isExternal()) {
322+
return;
323+
}
324+
325+
getOperation().walk([&](AffineParallelOp parOp) {
326+
DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
327+
328+
for (auto memrefVal : memrefsInPar) {
329+
SmallVector<Value> banks = createBanks(memrefVal, unrollFactor);
330+
memoryToBanks[memrefVal] = banks;
331+
}
332+
});
333+
334+
auto *ctx = &getContext();
335+
336+
RewritePatternSet patterns(ctx);
337+
338+
patterns.add<BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
339+
patterns.add<BankAffineStorePattern>(ctx, unrollFactor, memoryToBanks);
340+
patterns.add<BankReturnPattern>(ctx, memoryToBanks);
341+
342+
GreedyRewriteConfig config;
343+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
344+
345+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
346+
config))) {
347+
signalPassFailure();
348+
}
349+
350+
DenseSet<Block *> blocksToModify;
351+
for (auto &[memrefVal, banks] : memoryToBanks) {
293352
if (memrefVal.use_empty()) {
294353
if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
295354
blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
@@ -314,32 +373,7 @@ LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
314373
funcOp.setType(newFuncType);
315374
}
316375

317-
/// - `isDefinedOutsideRegion` returns true if the given value is invariant
318-
/// with
319-
/// respect to the given region. A common implementation might be:
320-
/// `value.getParentRegion()->isProperAncestor(region)`.
321-
322-
if (unrollFactor == 1) {
323-
// TODO: how to address "expected pattern to replace the root operation" if
324-
// just simply return success
325-
return success();
326-
}
327-
328-
return success();
329-
}
330-
331-
void ParallelUnroll::runOnOperation() {
332-
if (getOperation().isExternal()) {
333-
return;
334-
}
335-
336-
getOperation().walk([&](AffineParallelOp parOp) {
337-
(void)parallelUnrollByFactor(parOp, unrollFactor);
338-
return WalkResult::advance();
339-
});
340-
for (auto *op : opsToErase) {
341-
op->erase();
342-
}
376+
getOperation().dump();
343377
}
344378

345379
std::unique_ptr<OperationPass<func::FuncOp>>

0 commit comments

Comments
 (0)