Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 2 additions & 97 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@

#include <numeric>

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -1179,110 +1178,16 @@ void DecomposeONNXToONNXPass::runOnOperation() {
func::FuncOp function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
target.addLegalDialect<ONNXDialect, arith::ArithDialect, func::FuncDialect>();

// These ops will be decomposed into other ONNX ops. Hence, they will not be
// available after this pass.
target.addIllegalOp<ONNXCastLikeOp>();
target.addIllegalOp<ONNXClipV11Op>();
target.addIllegalOp<ONNXClipV12Op>();
target.addIllegalOp<ONNXClipV6Op>();
target.addIllegalOp<ONNXConstantOfShapeOp>();
target.addIllegalOp<ONNXDFTV17Op>();
target.addIllegalOp<ONNXGridSampleV16Op>();
target.addIllegalOp<ONNXLogSoftmaxOp>();
target.addIllegalOp<ONNXPadV11Op>();
target.addIllegalOp<ONNXPadV13Op>();
target.addIllegalOp<ONNXPadV18Op>();
target.addIllegalOp<ONNXPadV2Op>();
target.addIllegalOp<ONNXReduceL1Op>();
target.addIllegalOp<ONNXReduceL1V13Op>();
target.addIllegalOp<ONNXReduceL2Op>();
target.addIllegalOp<ONNXReduceL2V13Op>();
target.addIllegalOp<ONNXReduceLogSumExpOp>();
target.addIllegalOp<ONNXReduceLogSumOp>();
target.addIllegalOp<ONNXReduceMaxV18Op>();
target.addIllegalOp<ONNXReduceMinV18Op>();
target.addIllegalOp<ONNXReduceSumSquareOp>();
target.addIllegalOp<ONNXResizeV10Op>();
target.addIllegalOp<ONNXResizeV11Op>();
target.addIllegalOp<ONNXResizeV13Op>();
target.addIllegalOp<ONNXResizeV18Op>();
target.addIllegalOp<ONNXScalerOp>();
target.addIllegalOp<ONNXScatterOp>();
target.addIllegalOp<ONNXSequenceConstructOp>();
target.addIllegalOp<ONNXSoftmaxCrossEntropyLossOp>();
target.addIllegalOp<ONNXSplitV11Op>();
target.addIllegalOp<ONNXSplitV13Op>();
target.addIllegalOp<ONNXSqueezeV11Op>();
target.addIllegalOp<ONNXSumOp>();
target.addIllegalOp<ONNXUnsqueezeV11Op>();
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV7Op>();

if (!onnx_mlir::decomposeOpsInONNX.empty()) {
for (const auto &op : onnx_mlir::decomposeOpsInONNX) {
if (op == "HardSwish") {
target.addIllegalOp<ONNXHardSwishOp>();
}
}
}
target.addDynamicallyLegalOp<ONNXEinsumOp>([](ONNXEinsumOp op) {
return !onnx_mlir::DecomposeEinsumPattern::isDecomposable(op);
});

target.addDynamicallyLegalOp<ONNXConcatOp>([](ONNXConcatOp op) {
ONNXShapeOp shapeOp;
ONNXTransposeOp transposeOp;
return !isConcatFuseMatched(op, shapeOp, transposeOp);
});

target.addDynamicallyLegalOp<ONNXSequenceAtOp>([](ONNXSequenceAtOp op) {
return !onnx_mlir::canSequenceAtBeReplaced(op.getResult());
});

// Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
target.addDynamicallyLegalOp<ONNXConstantOp>([](ONNXConstantOp op) {
return !(op.getValueFloatAttr() || op.getValueFloatsAttr() ||
op.getValueIntAttr() || op.getValueIntsAttr() ||
op.getValueStringAttr() || op.getValueStringsAttr());
});

// Decompose CustomOp FusedMatMul introduced by onnxruntime:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
target.addDynamicallyLegalOp<ONNXCustomOp>([](ONNXCustomOp op) {
int64_t rankA, rankB;
FloatAttr alpha;
return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched(
op, alpha, rankA, rankB);
});

#ifdef ONNX_MLIR_ENABLE_STABLEHLO
// ONNXtoStablehlo pass has own rewriting for ConvTranspose Op using
// stablehlo ops. To avoid conflict with it, decomposing for ConvTranspose
// is disabled when the target is stablehlo.
if (this->target != "stablehlo") {
#endif
target.addDynamicallyLegalOp<ONNXConvTransposeOp>(
[](ONNXConvTransposeOp op) {
return !onnx_mlir::shouldDecomposeConvTransposeOp(op);
});
#ifdef ONNX_MLIR_ENABLE_STABLEHLO
}
#endif

RewritePatternSet patterns(context);
onnx_mlir::getDecomposeONNXToONNXPatterns(patterns);
patterns.insert<ReplaceCastLikeByCastPattern>(context);
#ifdef ONNX_MLIR_ENABLE_STABLEHLO
if (this->target == "stablehlo") {
populateDecomposingONNXBeforeStablehloPatterns(patterns, context);
target.addIllegalOp<ONNXSoftmaxOp>();
}
#endif

if (failed(applyPartialConversion(function, target, std::move(patterns))))
if (failed(applyPatternsGreedily(function, std::move(patterns))))
signalPassFailure();
}

Expand Down
Loading