diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 658689b059..098a91a5c6 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -22,10 +22,9 @@ #include -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/Compiler/CompilerOptions.hpp" #include "llvm/Support/Debug.h" @@ -1179,110 +1178,16 @@ void DecomposeONNXToONNXPass::runOnOperation() { func::FuncOp function = getOperation(); MLIRContext *context = &getContext(); - ConversionTarget target(getContext()); - target.addLegalDialect(); - - // These ops will be decomposed into other ONNX ops. Hence, they will not be - // available after this pass. - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - if (!onnx_mlir::decomposeOpsInONNX.empty()) { - for (const auto &op : onnx_mlir::decomposeOpsInONNX) { - if (op == "HardSwish") { - target.addIllegalOp(); - } - } - } - target.addDynamicallyLegalOp([](ONNXEinsumOp op) { - return !onnx_mlir::DecomposeEinsumPattern::isDecomposable(op); - }); - - target.addDynamicallyLegalOp([](ONNXConcatOp op) { - ONNXShapeOp shapeOp; - ONNXTransposeOp transposeOp; - return !isConcatFuseMatched(op, shapeOp, transposeOp); - }); - - target.addDynamicallyLegalOp([](ONNXSequenceAtOp op) { - return !onnx_mlir::canSequenceAtBeReplaced(op.getResult()); - }); - - // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs. - target.addDynamicallyLegalOp([](ONNXConstantOp op) { - return !(op.getValueFloatAttr() || op.getValueFloatsAttr() || - op.getValueIntAttr() || op.getValueIntsAttr() || - op.getValueStringAttr() || op.getValueStringsAttr()); - }); - - // Decompose CustomOp FusedMatMul introduced by onnxruntime: - // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul - target.addDynamicallyLegalOp([](ONNXCustomOp op) { - int64_t rankA, rankB; - FloatAttr alpha; - return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched( - op, alpha, rankA, rankB); - }); - -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - // ONNXtoStablehlo pass has own rewriting for ConvTranspose Op using - // stablehlo ops. To avoid conflict with it, decomposing for ConvTranspose - // is disabled when the target is stablehlo. - if (this->target != "stablehlo") { -#endif - target.addDynamicallyLegalOp( - [](ONNXConvTransposeOp op) { - return !onnx_mlir::shouldDecomposeConvTransposeOp(op); - }); -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - } -#endif - RewritePatternSet patterns(context); onnx_mlir::getDecomposeONNXToONNXPatterns(patterns); patterns.insert(context); #ifdef ONNX_MLIR_ENABLE_STABLEHLO if (this->target == "stablehlo") { populateDecomposingONNXBeforeStablehloPatterns(patterns, context); - target.addIllegalOp(); } #endif - if (failed(applyPartialConversion(function, target, std::move(patterns)))) + if (failed(applyPatternsGreedily(function, std::move(patterns)))) signalPassFailure(); }