1717#include " mlir/Dialect/ArmSME/Transforms/Passes.h"
1818#include " mlir/Dialect/ArmSME/Utils/Utils.h"
1919#include " mlir/Dialect/Func/IR/FuncOps.h"
20- #include " mlir/Dialect/Func/Transforms/OneToNFuncConversions .h"
20+ #include " mlir/Dialect/Func/Transforms/FuncConversions .h"
2121#include " mlir/Dialect/Index/IR/IndexDialect.h"
2222#include " mlir/Dialect/Index/IR/IndexOps.h"
2323#include " mlir/Dialect/MemRef/IR/MemRef.h"
2424#include " mlir/Dialect/SCF/IR/SCF.h"
2525#include " mlir/Dialect/SCF/Transforms/Patterns.h"
2626#include " mlir/Dialect/Utils/IndexingUtils.h"
2727#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
28- #include " mlir/Transforms/OneToNTypeConversion.h"
28+ #include " mlir/Transforms/DialectConversion.h"
29+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2930
3031#define DEBUG_TYPE " arm-sme-vector-legalization"
3132
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
172173/// Legalize `arith.constant dense<value>` splat operations to fit within SME
173174/// tiles by decomposing them into tile-sized operations.
174175struct LegalizeArithConstantOpsByDecomposition
175- : public OneToNOpConversionPattern <arith::ConstantOp> {
176- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
176+ : public OpConversionPattern <arith::ConstantOp> {
177+ using OpConversionPattern::OpConversionPattern ;
177178
178179 LogicalResult
179180 matchAndRewrite (arith::ConstantOp constantOp, OpAdaptor adaptor,
180- OneToNPatternRewriter &rewriter) const override {
181+ ConversionPatternRewriter &rewriter) const override {
181182 auto vectorType = dyn_cast<VectorType>(constantOp.getType ());
182183 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr ());
183184 if (!vectorType || !denseAttr || !denseAttr.isSplat ())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
191192 auto tileCount = getNumberOfSMETilesForVectorType (vectorType);
192193 auto tileSplat = rewriter.create <arith::ConstantOp>(
193194 constantOp.getLoc (), denseAttr.resizeSplat (smeTileType));
194- rewriter. replaceOp (constantOp, SmallVector<Value>(tileCount, tileSplat),
195- adaptor. getResultMapping () );
195+ SmallVector<Value> repl (tileCount, tileSplat);
196+ rewriter. replaceOpWithMultiple (constantOp, {repl} );
196197
197198 return success ();
198199 }
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
201202/// Legalize `vector.outerproduct` operations to fit within SME tiles by
202203/// decomposing them into tile-sized operations.
203204struct LegalizeVectorOuterProductOpsByDecomposition
204- : public OneToNOpConversionPattern <vector::OuterProductOp> {
205- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
205+ : public OpConversionPattern <vector::OuterProductOp> {
206+ using OpConversionPattern::OpConversionPattern ;
206207
207208 LogicalResult
208- matchAndRewrite (vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
209- OneToNPatternRewriter &rewriter) const override {
209+ matchAndRewrite (vector::OuterProductOp outerProductOp,
210+ OneToNOpAdaptor adaptor,
211+ ConversionPatternRewriter &rewriter) const override {
210212 auto vectorType = outerProductOp.getResultVectorType ();
211213 if (!isMultipleOfSMETileVectorType (vectorType))
212214 return rewriter.notifyMatchFailure (outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
219221 auto maskOp = outerProductOp.getMaskingOp ();
220222 mask = maskOp.getMask ();
221223 rootOp = maskOp;
224+ rewriter.setInsertionPoint (rootOp);
222225 }
223226
224227 if (!isSupportedMaskOp (mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
248251 resultSMETiles.push_back (maskedOuterProduct->getResult (0 ));
249252 }
250253
251- rewriter.replaceOp (rootOp, resultSMETiles, adaptor. getResultMapping () );
254+ rewriter.replaceOpWithMultiple (rootOp, { resultSMETiles} );
252255 return success ();
253256 }
254257};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
259262// (invalid). This pattern matches on `vector.mask` then calls into the
260263// `vector.outerproduct` pattern to work around this issue.
261264struct LegalizeMaskedVectorOuterProductOpsByDecomposition
262- : public OneToNOpConversionPattern <vector::MaskOp> {
263- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
265+ : public OpConversionPattern <vector::MaskOp> {
266+ using OpConversionPattern::OpConversionPattern ;
264267
265268 LogicalResult
266- matchAndRewrite (vector::MaskOp maskOp, OpAdaptor adaptor,
267- OneToNPatternRewriter &rewriter) const override {
269+ matchAndRewrite (vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
270+ ConversionPatternRewriter &rewriter) const override {
268271 if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
269272 maskOp.getMaskableOp ())) {
270273 LegalizeVectorOuterProductOpsByDecomposition pattern (*getTypeConverter (),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
279282/// Legalize `vector.transfer_read` operations to fit within SME tiles by
280283/// decomposing them into tile-sized operations.
281284struct LegalizeTransferReadOpsByDecomposition
282- : public OneToNOpConversionPattern <vector::TransferReadOp> {
283- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
285+ : public OpConversionPattern <vector::TransferReadOp> {
286+ using OpConversionPattern::OpConversionPattern ;
284287
285288 LogicalResult
286- matchAndRewrite (vector::TransferReadOp readOp, OpAdaptor adaptor,
287- OneToNPatternRewriter &rewriter) const override {
289+ matchAndRewrite (vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
290+ ConversionPatternRewriter &rewriter) const override {
288291 auto vectorType = readOp.getVectorType ();
289292 if (!isMultipleOfSMETileVectorType (vectorType))
290293 return rewriter.notifyMatchFailure (readOp,
@@ -319,20 +322,20 @@ struct LegalizeTransferReadOpsByDecomposition
319322 resultSMETiles.push_back (smeRead);
320323 }
321324
322- rewriter.replaceOp (readOp, resultSMETiles, adaptor. getResultMapping () );
325+ rewriter.replaceOpWithMultiple (readOp, { resultSMETiles} );
323326 return success ();
324327 }
325328};
326329
327330/// Legalize `vector.transfer_write` operations to fit within SME tiles by
328331/// decomposing them into tile-sized operations.
329332struct LegalizeTransferWriteOpsByDecomposition
330- : public OneToNOpConversionPattern <vector::TransferWriteOp> {
331- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
333+ : public OpConversionPattern <vector::TransferWriteOp> {
334+ using OpConversionPattern::OpConversionPattern ;
332335
333336 LogicalResult
334- matchAndRewrite (vector::TransferWriteOp writeOp, OpAdaptor adaptor,
335- OneToNPatternRewriter &rewriter) const override {
337+ matchAndRewrite (vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
338+ ConversionPatternRewriter &rewriter) const override {
336339 auto vectorType = writeOp.getVectorType ();
337340 if (!isMultipleOfSMETileVectorType (vectorType))
338341 return rewriter.notifyMatchFailure (writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
409412/// }
410413/// ```
411414struct LegalizeMultiTileTransferWriteAsStoreLoop
412- : public OneToNOpConversionPattern <vector::TransferWriteOp> {
413- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
415+ : public OpConversionPattern <vector::TransferWriteOp> {
416+ using OpConversionPattern::OpConversionPattern ;
414417
415418 LogicalResult
416- matchAndRewrite (vector::TransferWriteOp writeOp, OpAdaptor adaptor,
417- OneToNPatternRewriter &rewriter) const override {
419+ matchAndRewrite (vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
420+ ConversionPatternRewriter &rewriter) const override {
418421 if (writeOp.hasPureTensorSemantics ())
419422 return rewriter.notifyMatchFailure (
420423 writeOp, " TODO: tensor semantics are unsupported" );
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
936939 return success ();
937940 });
938941
939- patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
940- LiftIllegalVectorTransposeToMemory,
941- ConvertIllegalShapeCastOpsToTransposes,
942- LowerIllegalTransposeStoreViaZA>(context);
942+ // Apply preprocessing patterns.
943+ RewritePatternSet rewritePatterns (context);
944+ rewritePatterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
945+ LiftIllegalVectorTransposeToMemory,
946+ ConvertIllegalShapeCastOpsToTransposes,
947+ LowerIllegalTransposeStoreViaZA>(context);
948+ if (failed (
949+ applyPatternsGreedily (getOperation (), std::move (rewritePatterns))))
950+ return signalPassFailure ();
951+
943952 // Note: These two patterns are added with a high benefit to ensure:
944953 // - Masked outer products are handled before unmasked ones
945954 // - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
950959 LegalizeVectorOuterProductOpsByDecomposition,
951960 LegalizeTransferReadOpsByDecomposition,
952961 LegalizeTransferWriteOpsByDecomposition>(converter, context);
953- populateFuncTypeConversionPatterns (converter, patterns);
954- scf::populateSCFStructuralOneToNTypeConversions (converter, patterns);
955-
956- if (failed (applyPartialOneToNConversion (getOperation (), converter,
957- std::move (patterns))))
962+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
963+ converter);
964+ populateCallOpTypeConversionPattern (patterns, converter);
965+ populateReturnOpTypeConversionPattern (patterns, converter);
966+ scf::populateSCFStructuralTypeConversions (converter, patterns);
967+
968+ ConversionTarget target (getContext ());
969+ target.markUnknownOpDynamicallyLegal (
970+ [&](Operation *op) { return converter.isLegal (op); });
971+ target.addDynamicallyLegalOp <func::FuncOp>([&](func::FuncOp op) {
972+ return converter.isSignatureLegal (op.getFunctionType ());
973+ });
974+ if (failed (applyPartialConversion (getOperation (), target,
975+ std::move (patterns))))
958976 return signalPassFailure ();
959977 }
960978};
0 commit comments