diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index 50b6414151..e2a8625196 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -33,6 +33,7 @@ #include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerPasses.hpp" +#include "src/Compiler/DisposableGarbageCollector.hpp" #include "src/Compiler/OnnxToMlirPasses.hpp" #include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" #include "src/Dialect/Mlir/VectorMachineSupport.hpp" @@ -66,6 +67,139 @@ void configurePasses() { !disableSimdOption); } +void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, + bool donotScrubDisposableElementsAttr, OnnxToMlirOptions opts) { + // This is a transition from previous static passes to full dynamic passes + // Static passes are kept and the dynamic pass is added as IF-THEN + // with the static iteration. + // The reasons are + // 1. The debug flag, --print-ir-after/befor-all, can display IR for each + // static pass, but the dynamic pipeline will be viewed as one. MLIR + // may have solution that I am not aware of yet. + // 2. Easy to compare two approaches. + // In future, only the dynamic pass, ONNXOpTransformPass, will be used for + // this function. + + if (!donotScrubDisposableElementsAttr) + pm.addInstrumentation( + std::make_unique(pm.getContext())); + + // Decompose first. Eliminates some unsupported ops without shape inference. 
+ pm.addNestedPass(onnx_mlir::createDecomposeONNXToONNXPass( + /*target=*/"", opts.enableConvTransposeDecompose, + opts.enableConvTransposeDecomposeToPhasedConv, + opts.enableConvTranspose1dDecomposeToPhasedConv)); + if (!disableRecomposeOption) + pm.addNestedPass(onnx_mlir::createRecomposeONNXToONNXPass()); + if (enableONNXHybridPass) { + pm.addNestedPass(onnx_mlir::createONNXHybridTransformPass( + !disableRecomposeOption, opts.enableQuarkQuantizedLegalization, + opts.enableConvTransposeDecompose, + opts.enableConvTransposeDecomposeToPhasedConv, + opts.enableConvTranspose1dDecomposeToPhasedConv)); + // Convolution Optimization for CPU: enable when there are no accelerators. + if (targetCPU && enableConvOptPass) { + pm.addNestedPass(onnx_mlir::createConvOptONNXToONNXPass( + enableSimdDataLayout && !disableSimdOption)); + pm.addNestedPass( + onnx_mlir::createONNXHybridTransformPass(!disableRecomposeOption, + /*enableQuarkQuantizedOpsLegalization=*/false, + opts.enableConvTransposeDecompose, + opts.enableConvTransposeDecomposeToPhasedConv, + opts.enableConvTranspose1dDecomposeToPhasedConv)); + } + } else { + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + // Convolution Optimization for CPU: enable when there are no accelerators. 
+ if (targetCPU && enableConvOptPass) { + pm.addNestedPass(onnx_mlir::createConvOptONNXToONNXPass( + enableSimdDataLayout && !disableSimdOption)); + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + } + pm.addNestedPass( + onnx_mlir::createLegalizeQuarkQuantizedOpsPass()); + pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); + if (onnxOpTransformThreshold > 0) { + // Dynamic iterate in ONNXOpTransformPass + pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold, + onnxOpTransformReport, targetCPU, + enableSimdDataLayout && !disableSimdOption, enableConvOptPass, + !disableRecomposeOption)); + } else { + // Statically add extra passes + for (int i = 0; i < repeatOnnxTransform; i++) { + pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + pm.addNestedPass( + onnx_mlir::createConstPropONNXToONNXPass()); + } + } + } + + // Simplify shape-related ops. + pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass( + opts.enableQuarkQuantizedLegalization)); + + // Passes for removing redundant concat, slice and cast QDQ Ops + if (opts.enableRemoveDqQOp) + pm.addPass(createQDQOptONNXToONNXPass()); + if (opts.enableRemoveBinary) + pm.addPass(createFoldDQBinaryQPass()); + + // One more call to ONNX shape inference/canonicalization/... to update + // shape if possible. + if (enableONNXHybridPass) { + pm.addNestedPass(onnx_mlir::createONNXHybridTransformPass( + !disableRecomposeOption, opts.enableQuarkQuantizedLegalization, + opts.enableConvTransposeDecompose, + opts.enableConvTransposeDecomposeToPhasedConv, + opts.enableConvTranspose1dDecomposeToPhasedConv)); + } else { + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + } + + // Replace ONNXReturnOp with func::ReturnOp. + pm.addPass(onnx_mlir::createStandardFuncReturnPass()); + + // Clean dead code. 
+ pm.addPass(mlir::createSymbolDCEPass()); + + // Replace every DisposableElementsAttr with DenseElementsAttr. + if (!donotScrubDisposableElementsAttr) + pm.addPass(createScrubDisposablePass()); + + // Set onnx_node_name if it is missing. Keep this pass at the end of this + // function and just before instrumentation. + pm.addPass(createSetONNXNodeNamePass()); + + // Add instrumentation for Onnx Ops + // Keep this pass at the end of this function. + unsigned instrumentActions = instrumentControlBits; + if (profileIR == onnx_mlir::ProfileIRs::Onnx) { + instrumentStage = onnx_mlir::InstrumentStages::Onnx; + instrumentOps = "onnx.*"; + // Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp + // and InstrumentReportTime. Disable the last bit for + // InstrumentReportMemory because of its big overhead. Users can + // optionally enable the last bit by using + // --InstrumentReportMemory option. + instrumentActions |= (1 << 3) - 1; + } + if (instrumentStage == onnx_mlir::InstrumentStages::Onnx) + pm.addNestedPass( + onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); + // Print Signatures of each op at runtime if enabled. Should not run + // signature and instrument passes at the same time as time may include printf + // overheads. + if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE") + pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( + instrumentSignatures, instrumentOnnxNode)); +} + void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, std::string ONNXOpsStatFormat) { if (enableCSE) diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp index 6f7439c681..c05b6ec564 100644 --- a/src/Compiler/CompilerPasses.hpp +++ b/src/Compiler/CompilerPasses.hpp @@ -25,6 +25,19 @@ namespace onnx_mlir { // Configures passes up front based on command line options. 
void configurePasses(); +/* +struct OnnxToMlirOptions { + bool enableQuarkQuantizedLegalization = false; + bool enableConvTransposeDecompose = false; + bool enableConvTransposeDecomposeToPhasedConv = false; + bool enableConvTranspose1dDecomposeToPhasedConv = false; + bool enableRemoveDqQOp = false; + bool enableRemoveBinary = false; +}; + +void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, + bool donotScrubDisposableElementsAttr = false, OnnxToMlirOptions opts = {}); +*/ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, std::string ONNXOpsStatFilename); void addKrnlToAffinePasses(mlir::PassManager &pm); diff --git a/src/Compiler/OnnxToMlirPasses.hpp b/src/Compiler/OnnxToMlirPasses.hpp index 1c007ed6f7..b33019beaf 100644 --- a/src/Compiler/OnnxToMlirPasses.hpp +++ b/src/Compiler/OnnxToMlirPasses.hpp @@ -18,6 +18,7 @@ struct OnnxToMlirOptions { bool enableConvTranspose1dDecomposeToPhasedConv = false; bool enableRemoveDqQOp = true; bool enableRemoveDqQAroundOp = true; + bool enableRemoveBinary = true; bool disableRecomposeOption = false; bool enableONNXHybridPass = true; diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt index d647901b54..2a82efa8c3 100644 --- a/src/Dialect/ONNX/Transforms/CMakeLists.txt +++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt @@ -8,6 +8,12 @@ add_onnx_mlir_rewriter(DecomposeConvTranspose1dPhased) add_onnx_mlir_rewriter(ConstProp) add_onnx_mlir_rewriter(ConvOpt) +add_onnx_mlir_rewriter(QDQAroundOpOpt) +add_onnx_mlir_rewriter(QDQOpt) +add_onnx_mlir_rewriter(DQBinaryQOpt) + + + add_onnx_mlir_library(OMShapeInference ShapeInference.cpp @@ -44,6 +50,7 @@ add_onnx_mlir_library(OMONNXRewrite ConstProp.cpp QDQAroundOpOpt.cpp QDQOpt.cpp + DQBinaryQOpt.cpp ConvOpt.cpp Decompose.cpp DecomposeEinsum.cpp diff --git a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp new file mode 100644 index 0000000000..9f948751b8 --- 
/dev/null +++ b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp @@ -0,0 +1,570 @@ +//===- foldDqBinaryQPattern.cpp - Remove DQ-Binary-Q chains -----*- C++ -*-===// +// +// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#include "src/Pass/Passes.hpp" +#include "llvm/ADT/STLExtras.h" +#include // For std::llround +#include +#include + +using namespace mlir; +using namespace onnx_mlir; + +namespace { + +static ElementsAttr getElementAttributeFromConstant(Value val) { + if (!val) + return nullptr; + if (auto constOp = val.getDefiningOp()) + return mlir::dyn_cast(constOp.getValueAttr()); + return nullptr; +} + +// Equivalent to Python's NoMatch exception: here Nullopt indicates failure. 
+template +std::optional get_scalar_tensor_value(ONNXConstantOp constOp) { + auto elementsAttr = dyn_cast_or_null(constOp.getValueAttr()); + if (!elementsAttr) + return std::nullopt; + + Type elementType = elementsAttr.getElementType(); + + // Fast path: splat + if (elementsAttr.isSplat()) { + if (elementType.isa()) { + if constexpr (std::is_same_v || std::is_same_v) { + APFloat splatValue = elementsAttr.getSplatValue(); + return static_cast(splatValue.convertToDouble()); + } + } + if (auto intType = elementType.dyn_cast()) { + if constexpr (std::is_integral_v) { + APInt splatValue = elementsAttr.getSplatValue(); + if (intType.isUnsigned()) + return static_cast(splatValue.getZExtValue()); + else + return static_cast(splatValue.getSExtValue()); + } + } + return std::nullopt; + } + + // Non‑splat case: check rank + auto shapedTy = elementsAttr.getType().dyn_cast(); + if (!shapedTy || !shapedTy.hasStaticShape()) + return std::nullopt; + + // Case: rank 0 → scalar element directly + if (shapedTy.getRank() == 0) { + auto firstAttr = *elementsAttr.getValues().begin(); + if (auto fAttr = firstAttr.dyn_cast()) { + if constexpr (std::is_same_v || std::is_same_v) + return static_cast(fAttr.getValueAsDouble()); + } + if (auto iAttr = firstAttr.dyn_cast()) { + if constexpr (std::is_integral_v) + return static_cast(iAttr.getInt()); // signed ok + } + return std::nullopt; + } + + // Case: rank >= 1 → flatten & check all the same + std::set flattenedFP; + std::set flattenedInt; + + if (elementType.isa()) { + if constexpr (std::is_same_v || std::is_same_v) { + for (auto a : elementsAttr.getValues()) + flattenedFP.insert(a.getValueAsDouble()); + if (flattenedFP.size() == 1) + return static_cast(*flattenedFP.begin()); + } + } else if (auto intType = elementType.dyn_cast()) { + if constexpr (std::is_integral_v) { + for (auto a : elementsAttr.getValues()) + flattenedInt.insert(intType.isUnsigned() ? 
a.getUInt() : a.getInt()); + if (flattenedInt.size() == 1) + return static_cast(*flattenedInt.begin()); + } + } + + return std::nullopt; // mismatch or more than one unique value +} + +template +std::optional get_scalar_tensor_value_from_val(Value value) { + if (!value) { + return std::nullopt; + } + auto constOp = value.getDefiningOp(); + if (!constOp) { + return std::nullopt; + } + return get_scalar_tensor_value(constOp); +} + +static mlir::DenseElementsAttr makeScalarDEA( + mlir::ShapedType likeTy, double d, mlir::Type clampElemTy) { + using namespace mlir; + + auto ranked = likeTy.dyn_cast(); + if (!ranked || !ranked.hasStaticShape() || ranked.getNumElements() != 1) + return {}; + + Type outET = ranked.getElementType(); + Type useET = clampElemTy ? clampElemTy : outET; + + // If target is float, just create a float attr with outET semantics. + if (auto outFT = outET.dyn_cast()) { + // Round in the semantics of useET if it's float; otherwise just use d. + double dv = d; + if (auto useFT = useET.dyn_cast()) { + // Convert through APFloat with 'useET' semantics, then to double. + llvm::APFloat ap(d); + bool loses = false; + ap.convert(useFT.getFloatSemantics(), llvm::APFloat::rmNearestTiesToEven, + &loses); + dv = ap.convertToDouble(); + } + return DenseElementsAttr::get(ranked, FloatAttr::get(outFT, dv)); + } + + // If target is integer, round+clamp as per 'useET' (if integer), then emit as + // outET. + if (auto outIT = outET.dyn_cast()) { + // Decide signedness/width for clamping from useET if it's integer, else + // from outET. + IntegerType clampIT = + useET.isa() ? useET.cast() : outIT; + + int64_t iv = static_cast(std::llround(d)); + const unsigned bw = clampIT.getWidth(); + const bool isSigned = clampIT.isSigned(); + + const int64_t minV = isSigned ? (-(int64_t(1) << (bw - 1))) : 0; + const int64_t maxV = + isSigned ? 
((int64_t(1) << (bw - 1)) - 1) : ((int64_t(1) << bw) - 1); + iv = std::min(std::max(iv, minV), maxV); + + // Now re-materialize as the *output* ET (which may differ in width/sign). + // This guarantees the result type matches `likeTy`. + if (auto outSigned = outIT.isSigned()) { + // For signed out type, encode iv as signed. + return DenseElementsAttr::get(ranked, IntegerAttr::get(outIT, iv)); + } else { + // For unsigned out type, encode iv as unsigned (mask to width). + uint64_t u = static_cast(iv); + if (outIT.getWidth() < 64) + u &= ((uint64_t(1) << outIT.getWidth()) - 1); + return DenseElementsAttr::get(ranked, IntegerAttr::get(outIT, u)); + } + } + + return {}; +} + +static void updateInitializer(mlir::PatternRewriter &rewriter, + mlir::Operation *targetOp, mlir::Value oldInit, double newScalar, + mlir::Type clampElemTy) { + using namespace mlir; + + if (!targetOp || !oldInit) + return; + + auto oldCst = oldInit.getDefiningOp(); + if (!oldCst) + return; + + auto likeTy = oldInit.getType().dyn_cast(); + if (!likeTy || !likeTy.hasStaticShape() || likeTy.getNumElements() != 1) + return; + + DenseElementsAttr payload = makeScalarDEA(likeTy, newScalar, clampElemTy); + if (!payload) + return; + + // Single-use-by-target check. + auto singleUseByTarget = [&]() -> bool { + auto it = oldInit.use_begin(), e = oldInit.use_end(); + if (it == e) + return false; + auto *owner = it->getOwner(); + ++it; + return (it == e) && (owner == targetOp); + }; + + if (singleUseByTarget()) { + rewriter.modifyOpInPlace(oldCst, [&] { + oldCst->setAttr("value", payload); + // Keep constant canonical: + oldCst->removeAttr("sparse_value"); + oldCst->removeAttr("value_float"); + oldCst->removeAttr("value_floats"); + oldCst->removeAttr("value_int"); + oldCst->removeAttr("value_ints"); + oldCst->removeAttr("value_string"); + oldCst->removeAttr("value_strings"); + }); + return; + } + + // Multi-use: clone a fresh constant with same result type as oldInit. 
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(targetOp); + + OperationState st(targetOp->getLoc(), ONNXConstantOp::getOperationName()); + st.addTypes(likeTy); + st.addAttribute("value", payload); + + Operation *raw = Operation::create(st); + rewriter.insert(raw); + auto newCst = llvm::dyn_cast(raw); + if (!newCst) + return; + + // Replace exactly the matching operand. + for (unsigned i = 0, e = targetOp->getNumOperands(); i < e; ++i) { + if (targetOp->getOperand(i) == oldInit) { + targetOp->setOperand(i, newCst.getOutput()); + break; + } + } +} + +static LogicalResult tryRemoveQThenDQChain( + mlir::PatternRewriter &rewriter, mlir::ONNXDequantizeLinearOp dqOp) { + using namespace mlir; + + // Match Q -> DQ + auto qOp = dqOp.getX().template getDefiningOp(); + if (!qOp) { + return failure(); + } + + // 1) Axis / block_size must match + if (qOp.getAxis() != dqOp.getAxis()) { + return failure(); + } + if (qOp.getBlockSize() != dqOp.getBlockSize()) { + return failure(); + } + + // 2) Zero-points must match scalars/splats + auto zpQ = getElementAttributeFromConstant(qOp.getYZeroPoint()); + auto zpDQ = getElementAttributeFromConstant(dqOp.getXZeroPoint()); + if (!zpQ || !zpDQ || zpQ != zpDQ) { + return failure(); + } + + // 3) Scales must match scalars/splats + auto sQ = getElementAttributeFromConstant(qOp.getYScale()); + auto sDQ = getElementAttributeFromConstant(dqOp.getXScale()); + if (!sQ || !sDQ || sQ != sDQ) { + return failure(); + } + + // 4) Data type consistency: input of Q and output of DQ must have same elem + // type. 
+ auto qInTypeOp = qOp.getX().getType(); + auto dqOutTypeOp = dqOp.getResult().getType(); + + if (auto qInTensorType = qInTypeOp.dyn_cast()) { + if (auto dqOutTensorType = dqOutTypeOp.dyn_cast()) { + if (dqOutTensorType.getElementType() != qInTensorType.getElementType()) { + return failure(); + } + } else { + return failure(); + } + } else { + return failure(); + } + + // Replace DQ with Q's float input; erase Q if it becomes dead. + rewriter.replaceOp(dqOp, qOp.getX()); + if (qOp->use_empty()) { + rewriter.eraseOp(qOp); + } + + return success(); +} + +template +struct FoldBinaryThroughQDQ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +private: + struct MatchState { + ONNXDequantizeLinearOp dequantActivationOp = nullptr; + mlir::Type ScaleDtype; + mlir::Type zeroPointDtype; + double kValue = 0.0; + double dstScale = 0.0; + int64_t dstZeroPoint = 0; + double newScale = 0.0; + int64_t newZp = 0; + }; + + LogicalResult match_qdq(MatchState &state, ONNXDequantizeLinearOp dq1, + ONNXDequantizeLinearOp dq2) const { + + ONNXDequantizeLinearOp constantDqOp = nullptr; + ONNXConstantOp constantSourceOp = nullptr; + + // Case 1: Direct ConstantOp as input to the DQ. + if (auto constOp = dq1.getX().getDefiningOp()) { + constantDqOp = dq1; + state.dequantActivationOp = dq2; + constantSourceOp = constOp; + } else if (auto constOp = dq2.getX().getDefiningOp()) { + constantDqOp = dq2; + state.dequantActivationOp = dq1; + constantSourceOp = constOp; + } + // Case 2: The input to the DQ op comes from a chain whose input is a + // constant. 
+ else if (auto intermediateOp = dq1.getX().getDefiningOp()) { + if (auto constOp = + intermediateOp->getOperand(0).getDefiningOp()) { + constantDqOp = dq1; + state.dequantActivationOp = dq2; + constantSourceOp = constOp; + } + } else if (auto intermediateOp = dq2.getX().getDefiningOp()) { + if (auto constOp = + intermediateOp->getOperand(0).getDefiningOp()) { + constantDqOp = dq2; + state.dequantActivationOp = dq1; + constantSourceOp = constOp; + } + } + + if (!constantDqOp || !constantSourceOp || !state.dequantActivationOp) { + return failure(); + } + + // Find kvalue and store scale_dtype and zeroPointDtype + { + auto scalar_value_opt = + get_scalar_tensor_value(constantSourceOp); + if (!scalar_value_opt) { + return failure(); + } + Value scaleVal = constantDqOp.getXScale(); + Value zpVal = constantDqOp.getXZeroPoint(); + auto scale_value_opt = get_scalar_tensor_value_from_val(scaleVal); + auto zp_value_opt = get_scalar_tensor_value_from_val(zpVal); + if (!scale_value_opt || !zp_value_opt) { + return failure(); + } + // Calculate and store kValue. 
+ state.kValue = (*scalar_value_opt - *zp_value_opt) * *scale_value_opt; + + // # store dtype for creating new initializers with the same dtype + state.ScaleDtype = + mlir::cast(scaleVal.getType()).getElementType(); + state.zeroPointDtype = + mlir::cast(zpVal.getType()).getElementType(); + } + return success(); + } + + LogicalResult match_binary_op(MatchState &state, BinOp binaryOp) const { + ONNXConstantOp constantOp = nullptr; + + Value lhs = binaryOp.getOperand(0); + Value rhs = binaryOp.getOperand(1); + + // -------- Case A: lhs is DQ, rhs is Constant -------- + if (auto dqOp = lhs.getDefiningOp()) { + if (auto constOp = rhs.getDefiningOp()) { + state.dequantActivationOp = dqOp; + constantOp = constOp; + } + } + // -------- Case A reversed -------- + else if (auto dqOp = rhs.getDefiningOp()) { + if (auto constOp = lhs.getDefiningOp()) { + state.dequantActivationOp = dqOp; + constantOp = constOp; + } + } + + // -------- Fill state values for Case A and Case A reversed -------- + if (state.dequantActivationOp && constantOp) { + auto kValueOpt = get_scalar_tensor_value(constantOp); + if (!kValueOpt) { + return failure(); + } + state.kValue = kValueOpt.value(); + return success(); + } + + // -------- Case B: both inputs are DQ -------- + auto dqOp1 = lhs.getDefiningOp(); + auto dqOp2 = rhs.getDefiningOp(); + + if (dqOp1 && dqOp2) { + if (failed(match_qdq(state, dqOp1, dqOp2))) + return failure(); + return success(); + } + return failure(); + } + + LogicalResult check_needed_values( + const MatchState &state, Operation *binaryOp) const { + if (state.kValue == 0.0) { + if (isa(binaryOp)) { + return failure(); + } + } + if (state.dstScale == 0.0) { + if (isa(binaryOp)) { + return failure(); + } + } + return success(); + } + + static bool compute_new_scale_and_zp_values( + MatchState &state, Operation *binaryOp) { + double newScale = state.dstScale; + double newZpFloat = static_cast(state.dstZeroPoint); + const double kVal = state.kValue; + + if (isa(binaryOp)) { + 
newZpFloat -= (kVal / newScale); + + } else if (isa(binaryOp)) { + newZpFloat += (kVal / newScale); + + } else if (isa(binaryOp)) { + newScale *= kVal; + + } else if (isa(binaryOp)) { + newScale /= kVal; + + } else { + return false; + } + + int64_t newZp = static_cast(std::llround(newZpFloat)); + state.newScale = newScale; + state.newZp = newZp; + + return true; + } + +public: + LogicalResult matchAndRewrite( + BinOp op, PatternRewriter &rewriter) const override { + + // STEP 1: Match begin: Assuming only one user + auto quantOutputOp = dyn_cast(*op->user_begin()); + if (!quantOutputOp) { + return failure(); + } + + // Instantiate the state struct + MatchState state; + + // STEP 2 + if (failed(match_binary_op(state, op))) { + return failure(); + } + + // Store the value of the scale and zero point of the destination node + { + Value scaleVal = state.dequantActivationOp.getXScale(); + Value zpVal = state.dequantActivationOp.getXZeroPoint(); + auto scale_value_opt = get_scalar_tensor_value_from_val(scaleVal); + auto zp_value_opt = get_scalar_tensor_value_from_val(zpVal); + if (!scale_value_opt || !zp_value_opt) { + return failure(); + } + state.dstScale = scale_value_opt.value(); + state.dstZeroPoint = zp_value_opt.value(); + } + + // STEP 3 + if (failed(check_needed_values(state, op))) { + return failure(); + } + + // STEP 4 -Modify + if (!compute_new_scale_and_zp_values(state, op)) { + return failure(); + } + + // STEP 5: call initializer based on the binary op + ONNXDequantizeLinearOp dqAct = state.dequantActivationOp; + if constexpr (std::is_same_v || + std::is_same_v) { + Value zpVal = dqAct.getXZeroPoint(); + updateInitializer(rewriter, dqAct.getOperation(), zpVal, + static_cast(state.newZp), state.zeroPointDtype); + + } else if constexpr (std::is_same_v || + std::is_same_v) { + Value scaleVal = dqAct.getXScale(); + updateInitializer(rewriter, dqAct.getOperation(), scaleVal, + state.newScale, state.ScaleDtype); + } + + // STEP 6: Remove binary op + 
rewriter.replaceOp(op, dqAct.getResult()); + + // STEP 7: Remove Q->DQ chain + for (Operation *user : quantOutputOp.getY().getUsers()) { + if (auto tailDQ = llvm::dyn_cast(user)) { + (void)tryRemoveQThenDQChain(rewriter, tailDQ); + } + } + return success(); + } +}; + +struct FoldDQBinaryQPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldDQBinaryQPass) + + StringRef getArgument() const final { return "dq-binary-q-opt-onnx-to-onnx"; } + StringRef getDescription() const final { + return "Fold Add/Sub/Mul/Div through Q/DQ by updating scale/zero_point, " + "then remove trivial Q->DQ chains when safe."; + } + + void runOnOperation() override { + auto function = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns + .add, FoldBinaryThroughQDQ, + FoldBinaryThroughQDQ, FoldBinaryThroughQDQ>( + &getContext()); + if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +namespace onnx_mlir { +std::unique_ptr createFoldDQBinaryQPass() { + return std::make_unique(); +} +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 596ff251f4..8e4855e6dc 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -57,6 +57,7 @@ std::unique_ptr createConstPropONNXToONNXPass(); std::unique_ptr createQDQAroundOpOptONNXToONNXPass(); std::unique_ptr createQDQOptONNXToONNXPass(); +std::unique_ptr createFoldDQBinaryQPass(); /// Pass for instrument the ops in specific stage. 
std::unique_ptr createInstrumentPass(); diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 58fd210505..67750006f2 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -72,8 +72,13 @@ void registerOMPasses(int optLevel) { }); mlir::registerPass([]() -> std::unique_ptr { - return createQDQAroundOpOptONNXToONNXPass(); + return createFoldDQBinaryQPass(); }); + + mlir::registerPass([]() -> std::unique_ptr { + return createQDQAroundOpOptONNXToONNXPass(); + }); + mlir::registerPass([]() -> std::unique_ptr { return createQDQOptONNXToONNXPass(); diff --git a/test/mlir/onnx/onnx_remove_add.mlir b/test/mlir/onnx/onnx_remove_add.mlir new file mode 100644 index 0000000000..ea7451241b --- /dev/null +++ b/test/mlir/onnx/onnx_remove_add.mlir @@ -0,0 +1,255 @@ +// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s --split-input-file | FileCheck %s + +// 1) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1a +// CHECK: %[[ZP:.*]] = onnx.Constant dense<336> : tensor +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<39664> : tensor +%3 = onnx.Constant dense<40000> : tensor +%4 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} + +// ----- +// 2) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1b +// CHECK: %[[ZP:.*]] = onnx.Constant dense<336> : tensor +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<39664> : tensor +%3 = onnx.Constant dense<40000> : tensor +%4 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} + +// ----- +// 3) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1c +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Add"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} + +// ----- +// 4) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1d +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Add"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} +//----- +// 5) dq1-const-add-q-dq. remove->add, q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2a +// CHECK: %[[ZP:.*]] = onnx.Constant dense<100> : tensor +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<100> : tensor +%1 = onnx.Constant dense<1.000000e+01> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<1.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 6) const-dq1-add-q-dq. remove->add,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2b +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 7) const-dq1-add-q-dq. kval=0. remove->add,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3a +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 8) const-dq1-Sub-q-dq. dst_scale=0. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3b +// CHECK: onnx.Add +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 9) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern4 +// CHECK-NOT: onnx.Add +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Add"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %4, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %9 : tensor<1x1x1x128xf32> +} +//----- +// 10) const-dq1-Sub-tanh. 
remove->none +// CHECK-LABEL: func.func @test_removebinary_pattern5 +// CHECK: onnx.Add +// CHECK: onnx.Tanh +func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Add"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%8 = "onnx.Tanh"(%7) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 11) dq1-dq2-div-q-dq1-dq2-add-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->add, div +// CHECK-LABEL: func.func @test_removebinary_pattern6 +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.Div +func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<39664> : tensor +%5 = onnx.Constant dense<2.57987776E-5> : tensor +%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Div"(%6, %7) {onnx_node_name = "/bert/Sub"} : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : 
 si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%12 = "onnx.Add"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %14 : tensor<1x1x1x128xf32> +} +//----- +// +// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-add-q-dq. +// vectors with same values -> fusion +// CHECK-LABEL: func.func @test_removebinary_pattern7a +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<65535> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Add"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, 
 tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} +//----- +// +// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-add-q-dq. +// vectors with different values -> no fusion +// CHECK-LABEL: func.func @test_removebinary_pattern7b +// CHECK: onnx.Add +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Add"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} diff --git a/test/mlir/onnx/onnx_remove_div.mlir b/test/mlir/onnx/onnx_remove_div.mlir new file mode 100644 index 0000000000..d73b65846f --- /dev/null +++ b/test/mlir/onnx/onnx_remove_div.mlir @@ -0,0 +1,249 @@ +// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s -split-input-file | FileCheck 
%s + +// 1) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. +// CHECK-LABEL: func.func @test_removebinary_pattern1a +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.52590218E-9> : tensor +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 2) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1b +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.52590218E-9> : tensor +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 3) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1c +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Div"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} +//----- +// 4) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1d +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Div"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} +//----- +// 5) dq1-const-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2a +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.49999991E-7> : tensor +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.500000e-05> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+02> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 6) const-dq1-div-q-dq. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2b +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 7) const-dq1-div-q-dq. kval=0. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3c +// CHECK: onnx.Div +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern3c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 8) const-dq1-div-q-dq. dst_scale=0. remove->div,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3b +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 9) const-dq1-div-q-dq. q!=dq. 
remove->only div +// CHECK-LABEL: func.func @test_removebinary_pattern4 +// CHECK-NOT: onnx.Div +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %1, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 10) const-dq1-div-tanh. remove->none +// CHECK-LABEL: func.func @test_removebinary_pattern5 +// CHECK: onnx.Div +// CHECK: onnx.Tanh +func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.Tanh"(%6) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +return %7 : tensor<1x1x1x128xf32> +} +//----- +// 11) dq1-dq2-sub-q-dq1-dq2-div-Q-DQ. multi-use of scale and zp of dq-act before binary op. 
remove->div, sub +// CHECK-LABEL: func.func @test_removebinary_pattern6 +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.Sub +func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<39664> : tensor +%5 = onnx.Constant dense<2.57987776E-5> : tensor +%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Sub"(%6, %7) {onnx_node_name = "/bert/Sub"} : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%12 = "onnx.Div"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %14 : tensor<1x1x1x128xf32> +} +//----- +// +// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-div-q-dq. +// vectors with same values -> fusion; Div and QuantizeLinear are removed. 
+// CHECK-LABEL: func.func @test_removebinary_pattern7a +// CHECK-NOT: onnx.Div +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<65535> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Div"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} +//----- +// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-div-q-dq. +// Keep Div and QuantizeLinear present. 
+// CHECK-LABEL: func.func @test_removebinary_pattern7b +// CHECK: onnx.Div +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Div"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} diff --git a/test/mlir/onnx/onnx_remove_mul.mlir b/test/mlir/onnx/onnx_remove_mul.mlir new file mode 100644 index 0000000000..ad21131746 --- /dev/null +++ b/test/mlir/onnx/onnx_remove_mul.mlir @@ -0,0 +1,254 @@ +// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s -split-input-file | FileCheck %s + +// 1) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1a +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.152590215> : tensor +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- + +// 2) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1b +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.152590215> : tensor +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// Test fusing a DQ -> Mul -> Q -> DQ pattern. +// The constant input to the Mul is produced by a chain: Constant -> Identity -> DequantizeLinear. +// The pass should look through the Identity op, perform the fusion, and remove the redundant Q->DQ chain. +// 3) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1c +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Mul"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} +//----- +// 4) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1d +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Mul"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} +//----- +// 5) dq1-const-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2a +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.00152590219> : tensor +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+02> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 6) const-dq1-mul-q-dq. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2b +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 7) const-dq1-mul-q-dq. kval=0. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3a +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 8) const-dq1-mul-q-dq. dst_scale=0. remove->mul,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3b +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 9) const-dq1-mul-q-dq. q!=dq. 
remove->only mul +// CHECK-LABEL: func.func @test_removebinary_pattern4 +// CHECK-NOT: onnx.Mul +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %1, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 10) const-dq1-mul-tanh. remove->none +// CHECK-LABEL: func.func @test_removebinary_pattern5 +// CHECK: onnx.Mul +// CHECK: onnx.Tanh +func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.Tanh"(%6) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +return %7 : tensor<1x1x1x128xf32> +} +//----- +// 11) dq1-dq2-sub-q-dq1-dq2-mul-Q-DQ. multi-use of scale and zp of dq-act before binary op. 
remove->mul, sub +// CHECK-LABEL: func.func @test_removebinary_pattern6 +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.Sub +func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<39664> : tensor +%5 = onnx.Constant dense<2.57987776E-5> : tensor +%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Add"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%12 = "onnx.Mul"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %14 : tensor<1x1x1x128xf32> +} +//----- +// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. +// Per-axis vectors with identical values -> fusion: Mul and QuantizeLinear are removed.
+// CHECK-LABEL: func.func @test_removebinary_pattern7a +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<65535> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Mul"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} +//----- +// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. +// Keep Mul and QuantizeLinear present. 
+// CHECK-LABEL: func.func @test_removebinary_pattern7b +// CHECK: onnx.Mul +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Mul"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} diff --git a/test/mlir/onnx/onnx_remove_sub.mlir b/test/mlir/onnx/onnx_remove_sub.mlir new file mode 100644 index 0000000000..e92b1907cd --- /dev/null +++ b/test/mlir/onnx/onnx_remove_sub.mlir @@ -0,0 +1,257 @@ +// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s --split-input-file | FileCheck %s + +// 1) dq1-dq2(const input)-sub-q-dq. remove->sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1a +// CHECK: %[[ZP:.*]] = onnx.Constant dense<65535> : tensor +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Sub"(%6, %5) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %9 : tensor<1x1x1x128xf32> +} + +// ----- +// 2) dq1-dq2(const input)-sub-q-dq. remove->sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1b +// CHECK: %[[ZP:.*]] = onnx.Constant dense<65535> : tensor +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %9 : tensor<1x1x1x128xf32> +} + +// ----- +// 3) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1c +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Sub"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} + +// ----- +// 4) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern1d +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0> : tensor +%5 = "onnx.Identity"(%4) : (tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Sub"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %10 : tensor<1x1x1x128xf32> +} + +//----- +// 5) dq1-const-add-q-dq. remove->add, q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern2a +// CHECK: %[[ZP:.*]] = onnx.Constant dense<102> : tensor +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<101> : tensor +%1 = onnx.Constant dense<1.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<1.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Sub"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +// ----- +// 6) const-dq1-sub-q-dq. remove->sub,q-dq.
+// CHECK-LABEL: func.func @test_removebinary_pattern2b +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 7) const-dq1-sub-q-dq. kval=0. remove->sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3a +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +// CHECK: return +// CHECK-NOT: onnx.DequantizeLinear +func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<0.000000e+00> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +//----- +// 8) const-dq1-Sub-q-dq. dst_scale=0. remove->Sub,q-dq. 
+// CHECK-LABEL: func.func @test_removebinary_pattern3b +// CHECK: onnx.Sub +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<0.000000e+00> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<-1.000000e+04> : tensor +%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +// ----- +// 9) dq1-dq2(const input)-sub-q-dq. q!=dq. remove->only sub
+// CHECK-LABEL: func.func @test_removebinary_pattern4 +// CHECK-NOT: onnx.Sub +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %4, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %9 : tensor<1x1x1x128xf32> +} +//----- +// 10) const-dq1-Sub-tanh. 
remove->none +// CHECK-LABEL: func.func @test_removebinary_pattern5 +// CHECK: onnx.Sub +// CHECK: onnx.Tanh +func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<39664> : tensor +%4 = onnx.Constant dense<2.57987776E-5> : tensor +%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%8 = "onnx.Tanh"(%7) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +return %8 : tensor<1x1x1x128xf32> +} +// ----- +// 11) dq1-dq2-sub-q-dq1-dq2-mul-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->mul, sub +// CHECK-LABEL: func.func @test_removebinary_pattern6 +// CHECK-NOT: onnx.Mul +// CHECK-NOT: onnx.Sub +func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor +%1 = onnx.Constant dense<1.52590219E-5> : tensor +%2 = onnx.Constant dense<65535> : tensor +%3 = onnx.Constant dense<0.152590215> : tensor +%4 = onnx.Constant dense<39664> : tensor +%5 = onnx.Constant dense<2.57987776E-5> : tensor +%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%8 = "onnx.Mul"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> +%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>,
tensor, tensor) -> tensor<1x1x1x128xui16> +%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor +%12 = "onnx.Sub"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> +%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> +%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> +return %14 : tensor<1x1x1x128xf32> +} +// ----- +// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-sub-q-dq. +// vectors with same values -> fusion +// CHECK-LABEL: func.func @test_removebinary_pattern7a +// CHECK-NOT: onnx.Sub +// CHECK-NOT: onnx.QuantizeLinear +func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<65535> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Sub"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) ->
tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +} +// ----- +// +// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-sub-q-dq. +// vectors with different values -> no fusion +// CHECK-LABEL: func.func @test_removebinary_pattern7b +// CHECK: onnx.Sub +// CHECK: onnx.QuantizeLinear +func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { +%0 = onnx.Constant dense<0> : tensor<2xui16> +%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> +%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> +%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> +%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> +%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> +%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +%7 = "onnx.Sub"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> +%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> +%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> +return %9 : tensor<2x1x1x128xf32> +}