@@ -99,23 +99,29 @@ struct SCFForOpPattern
   ::mlir::LogicalResult
   matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor,
                   OpPatternRewriter &rewriter) const override {
+    // we need to update the SCF ForOp if the types of its init arg values
+    // do not match the types of the region iter args, or the init arg value
+    // is defined by a TilePackOp. Otherwise we can skip the op.
+    bool changed = false;
     llvm::SmallVector<mlir::Value> newInitArgs;
-    llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSize;
-    llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSize;
-    for (auto arg : adaptor.getInitArgs()) {
+    llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSizes;
+    llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSizes;
+    for (auto [i, arg] : llvm::enumerate(adaptor.getInitArgs())) {
+      auto blockArg = op.getRegionIterArg(i);
       auto defOp = arg.getDefiningOp<xetile::TilePackOp>();
-      if (auto blockSize = getValue(defOp)) {
-        newBlockSize.push_back(blockSize);
-        oldBlockSize.push_back(defOp.getInnerBlocksAttr());
-        auto repackOp = addUnpackAndPackOps(arg, blockSize, rewriter);
-        newInitArgs.push_back(repackOp);
-      } else {
-        oldBlockSize.push_back({});
-        newBlockSize.push_back({});
-        newInitArgs.push_back(arg);
-      }
+      auto oldSize = defOp ? defOp.getInnerBlocksAttr() : DenseI64ArrayAttr();
+      auto newSize = defOp ? getValue(blockArg) : DenseI64ArrayAttr();
+      auto newArg =
+          defOp && newSize ? addUnpackAndPackOps(arg, newSize, rewriter) : arg;
+      oldBlockSizes.push_back(oldSize);
+      newBlockSizes.push_back(newSize);
+      newInitArgs.push_back(newArg);
+      changed |= (newArg.getType() != blockArg.getType());
     }

+    if (!changed)
+      return mlir::failure();
+
     auto newOp = rewriter.create<mlir::scf::ForOp>(
         op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
         adaptor.getStep(), newInitArgs);
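The bail-out added above matters because this pattern runs under MLIR's greedy rewrite driver: a pattern that reports success without actually changing the IR can be re-applied indefinitely. A minimal, generic sketch of the idiom (not this pass's actual pattern; the anchor op and includes are assumptions for illustration):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

// Illustrates the "return failure() when nothing would change" idiom.
struct SkipIfUnchangedPattern : public mlir::OpRewritePattern<mlir::scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::ForOp op,
                  mlir::PatternRewriter &rewriter) const override {
    bool changed = false;
    // ... inspect `op` and set `changed` only when a rewrite is required ...
    if (!changed)
      return mlir::failure(); // report "did not apply" so the driver can stop
    // ... otherwise build the replacement with `rewriter` ...
    return mlir::success();
  }
};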
@@ -124,15 +130,19 @@ struct SCFForOpPattern
     mlir::Block *newBlock = newOp.getBody();
     llvm::SmallVector<mlir::Value> newArguments;
     auto numCtrlOprs = newOp.getNumInductionVars();
+    // remove the terminator of the new block
+    if (newBlock->mightHaveTerminator())
+      rewriter.eraseOp(newBlock->getTerminator());
+
     // add UnpackOp and PackOp pairs to the block arguments
     // if the corresponding init arg is repacked, such that
     // the old unpack op using it in the body will be folded
     for (auto [i, arg] : llvm::enumerate(newBlock->getArguments())) {
-      if (i < numCtrlOprs || !oldBlockSize[i - numCtrlOprs]) {
+      if (i < numCtrlOprs || !oldBlockSizes[i - numCtrlOprs]) {
         newArguments.push_back(arg);
       } else {
         auto repackOp =
-            addUnpackAndPackOps(arg, oldBlockSize[i - numCtrlOprs], rewriter);
+            addUnpackAndPackOps(arg, oldBlockSizes[i - numCtrlOprs], rewriter);
         newArguments.push_back(repackOp);
       }
     }
@@ -145,9 +155,9 @@ struct SCFForOpPattern
     mlir::OpBuilder::InsertionGuard g(rewriter);
     rewriter.startOpModification(yieldOp);
     for (auto [i, v] : llvm::enumerate(yieldOp.getResults())) {
-      if (newBlockSize[i]) {
+      if (newBlockSizes[i]) {
         rewriter.setInsertionPointAfter(v.getDefiningOp());
-        auto repack = addUnpackAndPackOps(v, newBlockSize[i], rewriter);
+        auto repack = addUnpackAndPackOps(v, newBlockSizes[i], rewriter);
         yieldOp->setOperand(i, repack);
       }
     }
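The startOpModification call above is, by the PatternRewriter contract, paired with a finalizeOpModification once the yield operands have been rewritten (the closing call presumably sits outside this hunk). A minimal sketch of that in-place update API, with a hypothetical helper name and a generic operation:

#include "mlir/IR/PatternMatch.h"

// Updates one operand of an existing op in place, notifying the rewriter so
// the greedy driver can re-enqueue the op's users.
void replaceOperandInPlace(mlir::Operation *op, unsigned index,
                           mlir::Value newValue,
                           mlir::PatternRewriter &rewriter) {
  rewriter.startOpModification(op);    // announce the pending in-place change
  op->setOperand(index, newValue);
  rewriter.finalizeOpModification(op); // commit the change
}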
@@ -157,8 +167,8 @@ struct SCFForOpPattern
     rewriter.setInsertionPointAfter(op);
     llvm::SmallVector<mlir::Value> newValues;
     for (auto [i, result] : llvm::enumerate(newOp->getResults())) {
-      if (newInitArgs[i].getDefiningOp<xetile::TilePackOp>()) {
-        auto unpack = addUnpackAndPackOps(result, oldBlockSize[i], rewriter);
+      if (oldBlockSizes[i]) {
+        auto unpack = addUnpackAndPackOps(result, oldBlockSizes[i], rewriter);
         newValues.push_back(unpack);
       } else {
         newValues.push_back(result);
@@ -247,6 +257,9 @@ struct UpdateTileOffsetOpPattern
   ::mlir::LogicalResult
   matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor,
                   OpPatternRewriter &rewriter) const override {
+    if (adaptor.getTile().getType() == op.getResult().getType())
+      return mlir::failure();
+
     rewriter.replaceOpWithNewOp<xetile::UpdateTileOffsetOp>(
         op, adaptor.getTile().getType(), adaptor.getTile(),
         adaptor.getOffsetX(), adaptor.getOffsetY());
@@ -291,9 +304,9 @@ class XeTileBlockAligningPass : public imex::impl::XeTileBlockAligningBase<
     // Use TopDown traversal order, and only look at existing ops
     // to simplify the code logic and speed up the pass
     mlir::GreedyRewriteConfig config;
+    config.enableRegionSimplification = false;
     config.useTopDownTraversal = true;
-    config.maxIterations = 2;
-    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
     if (failed(
             applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) {
       return signalPassFailure();
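For reference, a standalone sketch of driving patterns with this kind of configuration, assuming an MLIR revision where enableRegionSimplification is still a plain bool (as the diff suggests) and using only upstream APIs; it is not this pass's exact code:

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/raw_ostream.h"

mlir::LogicalResult
applyAligningPatterns(mlir::Operation *root, mlir::RewritePatternSet &&patterns) {
  mlir::GreedyRewriteConfig config;
  config.useTopDownTraversal = true;         // visit enclosing ops first
  config.enableRegionSimplification = false; // keep block structure untouched
  // Also revisit ops created by the patterns themselves, not only pre-existing ones.
  config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(root, std::move(patterns),
                                                      config))) {
    llvm::errs() << "greedy pattern application did not converge\n";
    return mlir::failure();
  }
  return mlir::success();
}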