@@ -99,23 +99,29 @@ struct SCFForOpPattern
9999 ::mlir::LogicalResult
100100 matchAndRewrite (mlir::scf::ForOp op, OpAdaptor adaptor,
101101 OpPatternRewriter &rewriter) const override {
102+ // we need to update the SCF ForOp if the types of its init arg values
103+ // do not match the types of the region iter args, or the init arg value
104+ // is defined by a TilePackOp. Otherwise we can skip the op.
105+ bool changed = false ;
102106 llvm::SmallVector<mlir::Value> newInitArgs;
103- llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSize;
104- llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSize;
105- for (auto arg : adaptor.getInitArgs ()) {
107+ llvm::SmallVector<mlir::DenseI64ArrayAttr> oldBlockSizes;
108+ llvm::SmallVector<mlir::DenseI64ArrayAttr> newBlockSizes;
109+ for (auto [i, arg] : llvm::enumerate (adaptor.getInitArgs ())) {
110+ auto blockArg = op.getRegionIterArg (i);
106111 auto defOp = arg.getDefiningOp <xetile::TilePackOp>();
107- if (auto blockSize = getValue (defOp)) {
108- newBlockSize.push_back (blockSize);
109- oldBlockSize.push_back (defOp.getInnerBlocksAttr ());
110- auto repackOp = addUnpackAndPackOps (arg, blockSize, rewriter);
111- newInitArgs.push_back (repackOp);
112- } else {
113- oldBlockSize.push_back ({});
114- newBlockSize.push_back ({});
115- newInitArgs.push_back (arg);
116- }
112+ auto oldSize = defOp ? defOp.getInnerBlocksAttr () : DenseI64ArrayAttr ();
113+ auto newSize = defOp ? getValue (blockArg) : DenseI64ArrayAttr ();
114+ auto newArg =
115+ defOp && newSize ? addUnpackAndPackOps (arg, newSize, rewriter) : arg;
116+ oldBlockSizes.push_back (oldSize);
117+ newBlockSizes.push_back (newSize);
118+ newInitArgs.push_back (newArg);
119+ changed |= (newArg.getType () != blockArg.getType ());
117120 }
118121
122+ if (!changed)
123+ return mlir::failure ();
124+
119125 auto newOp = rewriter.create <mlir::scf::ForOp>(
120126 op.getLoc (), adaptor.getLowerBound (), adaptor.getUpperBound (),
121127 adaptor.getStep (), newInitArgs);
@@ -124,15 +130,19 @@ struct SCFForOpPattern
124130 mlir::Block *newBlock = newOp.getBody ();
125131 llvm::SmallVector<mlir::Value> newArguments;
126132 auto numCtrlOprs = newOp.getNumInductionVars ();
133+ // remove the terminator of the new block
134+ if (newBlock->mightHaveTerminator ())
135+ rewriter.eraseOp (newBlock->getTerminator ());
136+
127137 // add UnpackOp and PackOp pairs to the block arguments
128138 // if the corresponding init arg is repacked, such that
129139 // the old unpack op using it in the body will be folded
130140 for (auto [i, arg] : llvm::enumerate (newBlock->getArguments ())) {
131- if (i < numCtrlOprs || !oldBlockSize [i - numCtrlOprs]) {
141+ if (i < numCtrlOprs || !oldBlockSizes [i - numCtrlOprs]) {
132142 newArguments.push_back (arg);
133143 } else {
134144 auto repackOp =
135- addUnpackAndPackOps (arg, oldBlockSize [i - numCtrlOprs], rewriter);
145+ addUnpackAndPackOps (arg, oldBlockSizes [i - numCtrlOprs], rewriter);
136146 newArguments.push_back (repackOp);
137147 }
138148 }
@@ -145,9 +155,9 @@ struct SCFForOpPattern
145155 mlir::OpBuilder::InsertionGuard g (rewriter);
146156 rewriter.startOpModification (yieldOp);
147157 for (auto [i, v] : llvm::enumerate (yieldOp.getResults ())) {
148- if (newBlockSize [i]) {
158+ if (newBlockSizes [i]) {
149159 rewriter.setInsertionPointAfter (v.getDefiningOp ());
150- auto repack = addUnpackAndPackOps (v, newBlockSize [i], rewriter);
160+ auto repack = addUnpackAndPackOps (v, newBlockSizes [i], rewriter);
151161 yieldOp->setOperand (i, repack);
152162 }
153163 }
@@ -157,8 +167,8 @@ struct SCFForOpPattern
157167 rewriter.setInsertionPointAfter (op);
158168 llvm::SmallVector<mlir::Value> newValues;
159169 for (auto [i, result] : llvm::enumerate (newOp->getResults ())) {
160- if (newInitArgs [i]. getDefiningOp <xetile::TilePackOp>() ) {
161- auto unpack = addUnpackAndPackOps (result, oldBlockSize [i], rewriter);
170+ if (oldBlockSizes [i]) {
171+ auto unpack = addUnpackAndPackOps (result, oldBlockSizes [i], rewriter);
162172 newValues.push_back (unpack);
163173 } else {
164174 newValues.push_back (result);
@@ -247,6 +257,9 @@ struct UpdateTileOffsetOpPattern
247257 ::mlir::LogicalResult
248258 matchAndRewrite (xetile::UpdateTileOffsetOp op, OpAdaptor adaptor,
249259 OpPatternRewriter &rewriter) const override {
260+ if (adaptor.getTile ().getType () == op.getResult ().getType ())
261+ return mlir::failure ();
262+
250263 rewriter.replaceOpWithNewOp <xetile::UpdateTileOffsetOp>(
251264 op, adaptor.getTile ().getType (), adaptor.getTile (),
252265 adaptor.getOffsetX (), adaptor.getOffsetY ());
@@ -291,9 +304,9 @@ class XeTileBlockAligningPass : public imex::impl::XeTileBlockAligningBase<
291304 // Use TopDown traversal order, and only look at existing ops
292305 // to simpliy the code logic and speedup the pass
293306 mlir::GreedyRewriteConfig config;
307+ config.enableRegionSimplification = false ;
294308 config.useTopDownTraversal = true ;
295- config.maxIterations = 2 ;
296- config.strictMode = GreedyRewriteStrictness::ExistingOps;
309+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
297310 if (failed (
298311 applyPatternsAndFoldGreedily (mod, std::move (patterns), config))) {
299312 return signalPassFailure ();
0 commit comments