From e2883899f45c4da6d590d03d1ba7e892fe911faa Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Tue, 1 Jul 2025 09:11:23 -0600 Subject: [PATCH] Replace usage of applyPartialConversion with applyPatternsGreedily in Decompose.cpp applyPatternsGreedily has the advantage that operation do not need to be manual marked as "illegal". The greedy rewriter is also generally simpler than partial-conversion. Signed-off-by: Jonas Rickert --- src/Dialect/ONNX/Transforms/Decompose.cpp | 99 +---------------------- 1 file changed, 2 insertions(+), 97 deletions(-) diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 658689b059..098a91a5c6 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -22,10 +22,9 @@ #include -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/Compiler/CompilerOptions.hpp" #include "llvm/Support/Debug.h" @@ -1179,110 +1178,16 @@ void DecomposeONNXToONNXPass::runOnOperation() { func::FuncOp function = getOperation(); MLIRContext *context = &getContext(); - ConversionTarget target(getContext()); - target.addLegalDialect(); - - // These ops will be decomposed into other ONNX ops. Hence, they will not be - // available after this pass. - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - if (!onnx_mlir::decomposeOpsInONNX.empty()) { - for (const auto &op : onnx_mlir::decomposeOpsInONNX) { - if (op == "HardSwish") { - target.addIllegalOp(); - } - } - } - target.addDynamicallyLegalOp([](ONNXEinsumOp op) { - return !onnx_mlir::DecomposeEinsumPattern::isDecomposable(op); - }); - - target.addDynamicallyLegalOp([](ONNXConcatOp op) { - ONNXShapeOp shapeOp; - ONNXTransposeOp transposeOp; - return !isConcatFuseMatched(op, shapeOp, transposeOp); - }); - - target.addDynamicallyLegalOp([](ONNXSequenceAtOp op) { - return !onnx_mlir::canSequenceAtBeReplaced(op.getResult()); - }); - - // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs. - target.addDynamicallyLegalOp([](ONNXConstantOp op) { - return !(op.getValueFloatAttr() || op.getValueFloatsAttr() || - op.getValueIntAttr() || op.getValueIntsAttr() || - op.getValueStringAttr() || op.getValueStringsAttr()); - }); - - // Decompose CustomOp FusedMatMul introduced by onnxruntime: - // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul - target.addDynamicallyLegalOp([](ONNXCustomOp op) { - int64_t rankA, rankB; - FloatAttr alpha; - return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched( - op, alpha, rankA, rankB); - }); - -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - // ONNXtoStablehlo pass has own rewriting for ConvTranspose Op using - // stablehlo ops. To avoid conflict with it, decomposing for ConvTranspose - // is disabled when the target is stablehlo. - if (this->target != "stablehlo") { -#endif - target.addDynamicallyLegalOp( - [](ONNXConvTransposeOp op) { - return !onnx_mlir::shouldDecomposeConvTransposeOp(op); - }); -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - } -#endif - RewritePatternSet patterns(context); onnx_mlir::getDecomposeONNXToONNXPatterns(patterns); patterns.insert(context); #ifdef ONNX_MLIR_ENABLE_STABLEHLO if (this->target == "stablehlo") { populateDecomposingONNXBeforeStablehloPatterns(patterns, context); - target.addIllegalOp(); } #endif - if (failed(applyPartialConversion(function, target, std::move(patterns)))) + if (failed(applyPatternsGreedily(function, std::move(patterns)))) signalPassFailure(); }