Skip to content

Commit b160c49

Browse files
authored
Improves the blocking logic and fixes Issue #413 (#722)
1 parent 973ff43 commit b160c49

File tree

9 files changed

+416
-139
lines changed

9 files changed

+416
-139
lines changed

include/imex/Dialect/XeTile/Transforms/Blocking.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,8 @@ class XeTileConversion : public imex::XeConversionPattern<AnalysisT> {
7474
mlir::DenseI64ArrayAttr::get(src.getContext(), innerBlocks));
7575
}
7676

77-
xetile::TilePackOp addPackOp(mlir::Value src,
78-
llvm::ArrayRef<int64_t> targetBlkSizes,
79-
OpPatternRewriter &rewriter) const {
77+
mlir::Value addPackOp(mlir::Value src, llvm::ArrayRef<int64_t> targetBlkSizes,
78+
OpPatternRewriter &rewriter) const {
8079
auto srcTy = src.getType().dyn_cast<mlir::VectorType>();
8180
assert(srcTy && targetBlkSizes.size() == 2);
8281
auto shape = srcTy.getShape();

include/imex/Utils/XeCommon.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class TileUsageAnalysis {
198198
// optimizations.
199199
class PropagateAnalysis {
200200
private:
201-
llvm::DenseMap<mlir::Operation *, mlir::DenseI64ArrayAttr> OpAttrMap;
201+
llvm::DenseMap<mlir::Value, mlir::DenseI64ArrayAttr> OpAttrMap;
202202

203203
public:
204204
PropagateAnalysis(mlir::Operation *op) {
@@ -214,11 +214,25 @@ class PropagateAnalysis {
214214
});
215215
}
216216

217-
bool maybeUpdated(mlir::Operation *op) { return OpAttrMap.count(op); }
217+
bool maybeUpdated(mlir::Operation *op) const {
218+
assert(op->getNumResults() == 1);
219+
auto v = op->getResult(0);
220+
return OpAttrMap.count(v);
221+
}
222+
223+
mlir::DenseI64ArrayAttr getValue(mlir::Value value) const {
224+
auto it = OpAttrMap.find(value);
225+
if (it != OpAttrMap.end())
226+
return it->second;
227+
return {};
228+
}
218229

219-
mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) {
220-
if (OpAttrMap.count(op))
221-
return OpAttrMap[op];
230+
mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) const {
231+
assert(op->getNumResults() == 1);
232+
auto v = op->getResult(0);
233+
auto it = OpAttrMap.find(v);
234+
if (it != OpAttrMap.end())
235+
return it->second;
222236
return {};
223237
}
224238

@@ -256,9 +270,9 @@ class PropagateAnalysis {
256270

257271
// stop when we meet a function.
258272
if (!op || llvm::isa<mlir::FunctionOpInterface>(op))
259-
return;
273+
continue;
260274

261-
OpAttrMap[op] = attr;
275+
OpAttrMap[value] = attr;
262276

263277
if (auto forOp = llvm::dyn_cast<mlir::scf::ForOp>(op)) {
264278
auto opr = getOperandForArg(forOp, value);
@@ -447,6 +461,10 @@ class XeConversionPattern : public mlir::RewritePattern {
447461
return {};
448462
}
449463

464+
mlir::DenseI64ArrayAttr getValue(mlir::Value value) const {
465+
return llvm::cast<PropagateAnalysis>(analysis).getValue(value);
466+
}
467+
450468
template <typename = typename std::enable_if<
451469
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
452470
bool isForDPASA(imex::xetile::LoadTileOp op) const {

lib/Dialect/XeTile/Transforms/BlockAligning.cpp

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,29 @@ struct SCFForOpPattern
9999
::mlir::LogicalResult
100100
matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
101101
OpPatternRewriter &rewriter) const override {
102+
// we need to update the SCF ForOp if the types of its init arg values
103+
// do not match the types of the region iter args, or the init arg value
104+
// is defined by a TilePackOp. Otherwise we can skip the op.
105+
bool changed = false;
102106
llvm::SmallVector<mlir::Value> newInitArgs;
103-
llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSize;
104-
llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSize;
105-
for (auto arg : adaptor.getInitArgs()) {
107+
llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSizes;
108+
llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSizes;
109+
for (auto [i, arg] : llvm::enumerate(adaptor.getInitArgs())) {
110+
auto blockArg = op.getRegionIterArg(i);
106111
auto defOp = arg.getDefiningOp<xetile::TilePackOp>();
107-
if (auto blockSize = getValue(defOp)) {
108-
newBlockSize.push_back(blockSize);
109-
oldBlockSize.push_back(defOp.getInnerBlocksAttr());
110-
auto repackOp = addUnpackAndPackOps(arg, blockSize, rewriter);
111-
newInitArgs.push_back(repackOp);
112-
} else {
113-
oldBlockSize.push_back({});
114-
newBlockSize.push_back({});
115-
newInitArgs.push_back(arg);
116-
}
112+
auto oldSize = defOp ? defOp.getInnerBlocksAttr() : DenseI64ArrayAttr();
113+
auto newSize = defOp ? getValue(blockArg) : DenseI64ArrayAttr();
114+
auto newArg =
115+
defOp && newSize ? addUnpackAndPackOps(arg, newSize, rewriter) : arg;
116+
oldBlockSizes.push_back(oldSize);
117+
newBlockSizes.push_back(newSize);
118+
newInitArgs.push_back(newArg);
119+
changed |= (newArg.getType() != blockArg.getType());
117120
}
118121

122+
if (!changed)
123+
return mlir::failure();
124+
119125
auto newOp = rewriter.create<mlir::scf::ForOp>(
120126
op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
121127
adaptor.getStep(), newInitArgs);
@@ -124,15 +130,19 @@ struct SCFForOpPattern
124130
mlir::Block *newBlock = newOp.getBody();
125131
llvm::SmallVector<mlir::Value> newArguments;
126132
auto numCtrlOprs = newOp.getNumInductionVars();
133+
// remove the terminator of the new block
134+
if (newBlock->mightHaveTerminator())
135+
rewriter.eraseOp(newBlock->getTerminator());
136+
127137
// add UnpackOp and PackOp pairs to the block arguments
128138
// if the corresponding init arg is repacked, such that
129139
// the old unpack op using it in the body will be folded
130140
for (auto [i, arg] : llvm::enumerate(newBlock->getArguments())) {
131-
if (i < numCtrlOprs || !oldBlockSize[i - numCtrlOprs]) {
141+
if (i < numCtrlOprs || !oldBlockSizes[i - numCtrlOprs]) {
132142
newArguments.push_back(arg);
133143
} else {
134144
auto repackOp =
135-
addUnpackAndPackOps(arg, oldBlockSize[i - numCtrlOprs], rewriter);
145+
addUnpackAndPackOps(arg, oldBlockSizes[i - numCtrlOprs], rewriter);
136146
newArguments.push_back(repackOp);
137147
}
138148
}
@@ -145,9 +155,9 @@ struct SCFForOpPattern
145155
mlir::OpBuilder::InsertionGuard g(rewriter);
146156
rewriter.startOpModification(yieldOp);
147157
for (auto [i, v] : llvm::enumerate(yieldOp.getResults())) {
148-
if (newBlockSize[i]) {
158+
if (newBlockSizes[i]) {
149159
rewriter.setInsertionPointAfter(v.getDefiningOp());
150-
auto repack = addUnpackAndPackOps(v, newBlockSize[i], rewriter);
160+
auto repack = addUnpackAndPackOps(v, newBlockSizes[i], rewriter);
151161
yieldOp->setOperand(i, repack);
152162
}
153163
}
@@ -157,8 +167,8 @@ struct SCFForOpPattern
157167
rewriter.setInsertionPointAfter(op);
158168
llvm::SmallVector<mlir::Value> newValues;
159169
for (auto [i, result] : llvm::enumerate(newOp->getResults())) {
160-
if (newInitArgs[i].getDefiningOp<xetile::TilePackOp>()) {
161-
auto unpack = addUnpackAndPackOps(result, oldBlockSize[i], rewriter);
170+
if (oldBlockSizes[i]) {
171+
auto unpack = addUnpackAndPackOps(result, oldBlockSizes[i], rewriter);
162172
newValues.push_back(unpack);
163173
} else {
164174
newValues.push_back(result);
@@ -247,6 +257,9 @@ struct UpdateTileOffsetOpPattern
247257
::mlir::LogicalResult
248258
matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor,
249259
OpPatternRewriter &rewriter) const override {
260+
if (adaptor.getTile().getType() == op.getResult().getType())
261+
return mlir::failure();
262+
250263
rewriter.replaceOpWithNewOp<xetile::UpdateTileOffsetOp>(
251264
op, adaptor.getTile().getType(), adaptor.getTile(),
252265
adaptor.getOffsetX(), adaptor.getOffsetY());
@@ -291,9 +304,9 @@ class XeTileBlockAligningPass : public imex::impl::XeTileBlockAligningBase<
291304
// Use TopDown traversal order, and only look at existing ops
292305
// to simplify the code logic and speed up the pass
293306
mlir::GreedyRewriteConfig config;
307+
config.enableRegionSimplification = false;
294308
config.useTopDownTraversal = true;
295-
config.maxIterations = 2;
296-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
309+
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
297310
if (failed(
298311
applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) {
299312
return signalPassFailure();

0 commit comments

Comments
 (0)