From ab1c32d5b48eb8e24d6bb8ddf7cd263a45d3c60c Mon Sep 17 00:00:00 2001 From: sushmita Date: Mon, 10 Nov 2025 14:59:58 +0530 Subject: [PATCH 1/5] remove_binary_update --- src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp | 386 ++++++++++++++----- test/mlir/onnx/onnx_remove_add.mlir | 255 ------------ test/mlir/onnx/onnx_remove_div.mlir | 249 ------------ test/mlir/onnx/onnx_remove_mul.mlir | 254 ------------ test/mlir/onnx/onnx_remove_sub.mlir | 257 ------------ 5 files changed, 290 insertions(+), 1111 deletions(-) delete mode 100644 test/mlir/onnx/onnx_remove_add.mlir delete mode 100644 test/mlir/onnx/onnx_remove_div.mlir delete mode 100644 test/mlir/onnx/onnx_remove_mul.mlir delete mode 100644 test/mlir/onnx/onnx_remove_sub.mlir diff --git a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp index 172453da0b..8c3f562612 100644 --- a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp +++ b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp @@ -233,62 +233,89 @@ static void updateInitializer(mlir::PatternRewriter &rewriter, } } -static LogicalResult tryRemoveQThenDQChain( - mlir::PatternRewriter &rewriter, mlir::ONNXDequantizeLinearOp dqOp) { +// Returns success() iff Q->DQ is *removable* under strict checks. +// If doRewrite==true, it also *applies* the rewrite for this DQ (replaces DQ +// with Q.x). +static mlir::LogicalResult tryRemoveQThenDQChain( + mlir::PatternRewriter &rewriter, mlir::ONNXDequantizeLinearOp dqOp, + bool doRewrite) { using namespace mlir; - // Match Q -> DQ + // Match direct Q -> DQ auto qOp = dqOp.getX().template getDefiningOp(); - if (!qOp) { + if (!qOp) return failure(); - } // 1) Axis / block_size must match - if (qOp.getAxis() != dqOp.getAxis()) { + if (qOp.getAxis() != dqOp.getAxis()) return failure(); - } - if (qOp.getBlockSize() != dqOp.getBlockSize()) { + if (qOp.getBlockSize() != dqOp.getBlockSize()) return failure(); - } // 2) Zero-points must match scalars/splats auto zpQ = getElementAttributeFromONNXValue(qOp.getYZeroPoint()); auto zpDQ = getElementAttributeFromONNXValue(dqOp.getXZeroPoint()); - if (!zpQ || !zpDQ || zpQ != zpDQ) { + if (!zpQ || !zpDQ || zpQ != zpDQ) return failure(); - } // 3) Scales must match scalars/splats auto sQ = getElementAttributeFromONNXValue(qOp.getYScale()); auto sDQ = getElementAttributeFromONNXValue(dqOp.getXScale()); - if (!sQ || !sDQ || sQ != sDQ) { + if (!sQ || !sDQ || sQ != sDQ) return failure(); - } - // 4) Data type consistency: input of Q and output of DQ must have same elem - // type. + // 4) Element type parity between Q.x and DQ.y auto qInTypeOp = qOp.getX().getType(); auto dqOutTypeOp = dqOp.getResult().getType(); + auto qInT = qInTypeOp.dyn_cast(); + auto dqOutT = dqOutTypeOp.dyn_cast(); + if (!qInT || !dqOutT) + return failure(); + if (dqOutT.getElementType() != qInT.getElementType()) + return failure(); - if (auto qInTensorType = qInTypeOp.dyn_cast()) { - if (auto dqOutTensorType = dqOutTypeOp.dyn_cast()) { - if (dqOutTensorType.getElementType() != qInTensorType.getElementType()) { - return failure(); + // If only checking removability, stop here. + if (!doRewrite) + return success(); + + // Rewrite: replace DQ with Q's float input. + rewriter.replaceOp(dqOp, qOp.getX()); + return success(); +} + +// If doRewrite=false: returns true iff *any* removable DQ user exists (no +// mutation). If doRewrite=true : performs removals and returns true iff it +// removed at least one DQ. Also erases Q if it becomes dead after removals. +static bool Remove_Q_Plus_DQ( + mlir::PatternRewriter &rewriter, ONNXQuantizeLinearOp qOp, bool doRewrite) { + using namespace mlir; + if (!qOp) + return false; + + int removableCount = 0; + int removedCount = 0; + + // Safe iteration while potentially mutating (when doRewrite==true) + auto users = llvm::make_early_inc_range(qOp.getY().getUsers()); + for (Operation *user : users) { + if (auto tailDQ = llvm::dyn_cast(user)) { + if (succeeded( + tryRemoveQThenDQChain(rewriter, tailDQ, /*doRewrite*/ false))) { + ++removableCount; + if (doRewrite) { + if (succeeded( + tryRemoveQThenDQChain(rewriter, tailDQ, /*doRewrite*/ true))) + ++removedCount; + } } - } else { - return failure(); } - } else { - return failure(); } - // Replace DQ with Q's float input; erase Q if it becomes dead. - rewriter.replaceOp(dqOp, qOp.getX()); - if (qOp->use_empty()) { + if (doRewrite && qOp->use_empty()) { rewriter.eraseOp(qOp); } - return success(); + return doRewrite ? (removedCount > 0) : (removableCount > 0); } static bool isValuePreservingOp(mlir::Operation *op) { @@ -304,14 +331,25 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { private: struct MatchState { - ONNXDequantizeLinearOp dequantActivationOp = nullptr; - double kValue = 0.0; // Dequantized value of the constant node - double dstScale = 0.0; // Destination node's scale - int64_t dstZeroPoint = 0; // Destination node's zero-point - double newScale = - 0.0; // New scale of the destination node after constant folding - int64_t newZp = - 0; // New zero-point of the estination node after constant folding + ONNXDequantizeLinearOp dequantActivationOfBinOp = + nullptr; // BinaryOP parent op + ONNXQuantizeLinearOp quantOutputOfBinOp = nullptr; // BinaryOp child op + + // Destination/source ops picked by find_destination_node() + mlir::Operation *dstNode = + nullptr; // DQ when folding into DQ, or Q when folding into Q + mlir::Operation *srcNode = nullptr; + + // Current destination params (read before fold) + double dstScale = 0.0; + int64_t dstZeroPoint = 0; + + // New params to write after fold + double newScale = 0.0; + int64_t newZp = 0; + + // Constant value folded + double kValue = 0.0; }; LogicalResult match_qdq(mlir::PatternRewriter &rewriter, MatchState &state, @@ -323,36 +361,39 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // Case 1: Direct ConstantOp as input to the DQ. if (auto constOp = dq1.getX().getDefiningOp()) { constantDqOp = dq1; - state.dequantActivationOp = dq2; + state.dequantActivationOfBinOp = dq2; constantSourceOp = constOp; } else if (auto constOp = dq2.getX().getDefiningOp()) { constantDqOp = dq2; - state.dequantActivationOp = dq1; + state.dequantActivationOfBinOp = dq1; constantSourceOp = constOp; } // Case 2: The input to the DQ op comes from a chain whose input is a // constant. - else if (auto intermediateOp = dq1.getX().getDefiningOp()) { - if (isValuePreservingOp(intermediateOp)) { - if (auto constOp = - intermediateOp->getOperand(0).getDefiningOp()) { - constantDqOp = dq1; - state.dequantActivationOp = dq2; - constantSourceOp = constOp; + else { + if (auto intermediateOp = dq1.getX().getDefiningOp()) { + if (isValuePreservingOp(intermediateOp)) { + if (auto constOp = intermediateOp->getOperand(0) + .getDefiningOp()) { + constantDqOp = dq1; + state.dequantActivationOfBinOp = dq2; + constantSourceOp = constOp; + } } } - } else if (auto intermediateOp = dq2.getX().getDefiningOp()) { - if (isValuePreservingOp(intermediateOp)) { - if (auto constOp = - intermediateOp->getOperand(0).getDefiningOp()) { - constantDqOp = dq2; - state.dequantActivationOp = dq1; - constantSourceOp = constOp; + if (auto intermediateOp = dq2.getX().getDefiningOp()) { + if (isValuePreservingOp(intermediateOp)) { + if (auto constOp = intermediateOp->getOperand(0) + .getDefiningOp()) { + constantDqOp = dq2; + state.dequantActivationOfBinOp = dq1; + constantSourceOp = constOp; + } } } } - if (!constantDqOp || !constantSourceOp || !state.dequantActivationOp) { + if (!constantDqOp || !constantSourceOp || !state.dequantActivationOfBinOp) { return failure(); } @@ -383,24 +424,32 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { Value lhs = binaryOp.getOperand(0); Value rhs = binaryOp.getOperand(1); + Value out = binaryOp->getResult(0); + state.quantOutputOfBinOp = + dyn_cast(*out.getUsers().begin()); + // auto qOut = getUniqueQuantUserOrNull(binaryOp->getResult(0)); + // if (!qOut) + // return rewriter.notifyMatchFailure(binaryOp, + // "binary result must have exactly one ONNXQuantizeLinearOp user"); + // state.quantOutputOfBinOp = qOut; // -------- Case A: lhs is DQ, rhs is Constant -------- if (auto dqOp = lhs.getDefiningOp()) { if (auto constOp = rhs.getDefiningOp()) { - state.dequantActivationOp = dqOp; + state.dequantActivationOfBinOp = dqOp; constantOp = constOp; } } // -------- Case A reversed -------- else if (auto dqOp = rhs.getDefiningOp()) { if (auto constOp = lhs.getDefiningOp()) { - state.dequantActivationOp = dqOp; + state.dequantActivationOfBinOp = dqOp; constantOp = constOp; } } // -------- Fill state values for Case A and Case A reversed -------- - if (state.dequantActivationOp && constantOp) { + if (state.dequantActivationOfBinOp && constantOp) { auto kValueOpt = getScalarTensorValue(constantOp); if (!kValueOpt) { return rewriter.notifyMatchFailure( @@ -420,7 +469,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { return failure(); } - LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, + /*LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, const MatchState &state, Operation *binaryOp) const { if (state.kValue == 0.0) { if (isa(binaryOp)) { @@ -438,21 +487,66 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } } return success(); + }*/ + + LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, + const MatchState &state, Operation *binaryOp) const { + const bool dstIsDQ = llvm::isa(state.dstNode); + const bool dstIsQ = llvm::isa(state.dstNode); + + // scale_new = scale / k (for Div when folding into DQ) OR scale_new = + // scale / k (for Mul when folding into Q) Avoid division by zero when k == + // 0 in those cases. + if (state.kValue == 0.0) { + if (dstIsDQ && llvm::isa(binaryOp)) { + return rewriter.notifyMatchFailure(binaryOp, + "when opType is Div, remove binary op only if k_value is not zero, " + "to avoid ZeroDivisionError"); + } + else if (dstIsQ && llvm::isa(binaryOp)) { + return rewriter.notifyMatchFailure(binaryOp, + "when opType is Mul, remove binary op only if k_value is not zero, " + "to avoid ZeroDivisionError"); + } + } + + // k/scale is used for Add/Sub to update zero_point. + // Avoid division by zero when dstScale == 0. + if (state.dstScale == 0.0 && (llvm::isa(binaryOp))) { + return rewriter.notifyMatchFailure(binaryOp, + "when opType is Add or Sub, remove binary op only if scale is not " + "zero, to avoid ZeroDivisionError"); + } + + return mlir::success(); } static bool compute_new_scale_and_zp_values(MatchState &state) { double newScale = state.dstScale; double newZpFloat = static_cast(state.dstZeroPoint); const double kVal = state.kValue; + const bool dstIsDQ = llvm::isa(state.dstNode); if constexpr (std::is_same_v) { - newZpFloat -= (kVal / newScale); + if (dstIsDQ) + newZpFloat -= (kVal / newScale); + else + newZpFloat += (kVal / newScale); } else if constexpr (std::is_same_v) { - newZpFloat += (kVal / newScale); + if (dstIsDQ) + newZpFloat += (kVal / newScale); + else + newZpFloat -= (kVal / newScale); } else if constexpr (std::is_same_v) { - newScale *= kVal; + if (dstIsDQ) + newScale *= kVal; + else + newScale /= kVal; } else if constexpr (std::is_same_v) { - newScale /= kVal; + if (dstIsDQ) + newScale /= kVal; + else + newScale *= kVal; } else { static_assert(std::is_same_v || std::is_same_v || @@ -470,6 +564,79 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { return true; } + static ONNXQuantizeLinearOp getSingleQuantizeUser(Value v) { + ONNXQuantizeLinearOp q = nullptr; + for (Operation *u : v.getUsers()) { + if (auto cand = dyn_cast(u)) { + if (q) + return nullptr; // more than one Quantize user + q = cand; + } + } + return q; + } + + LogicalResult findDestinationNode( + mlir::PatternRewriter &rewriter, MatchState &state, Operation *op) const { + auto dq = state.dequantActivationOfBinOp; + if (!dq) + return rewriter.notifyMatchFailure( + op, "dequantActivationOfBinOp not set in MatchState"); + + // Producer Quantize of DQ.x (may be null if DQ consumes a block arg or + // non-Q) + auto q = dq.getX().template getDefiningOp(); + + // Non-mutating probe: removable only if Q exists and matches + bool removableQDQ = false; + if (q) + removableQDQ = Remove_Q_Plus_DQ(rewriter, q, /*doRewrite=*/false); + + // Branch detection on distinct users + auto hasBranchOnValue = [](mlir::Value v) { + llvm::SmallPtrSet uniq; + for (auto *u : v.getUsers()) + uniq.insert(u); + return uniq.size() > 1; + }; + const bool branch_after = hasBranchOnValue(dq.getY()); + const bool branch_before = q ? hasBranchOnValue(q.getY()) : false; + const bool branch_on_dequant_activation = branch_after || branch_before; + + // If we cannot remove Q->DQ (or there is branching), fold into DQ + if (!removableQDQ || branch_on_dequant_activation) { + state.dstNode = dq.getOperation(); // DQ + state.srcNode = state.quantOutputOfBinOp.getOperation(); // Q after binop + + auto scaleOpt = getScalarTensorValueFromVal(dq.getXScale()); + auto zpOpt = getScalarTensorValueFromVal(dq.getXZeroPoint()); + if (!scaleOpt || !zpOpt) + return rewriter.notifyMatchFailure( + dq, "DQ x_scale/x_zero_point must be scalar"); + state.dstScale = *scaleOpt; + state.dstZeroPoint = *zpOpt; + return success(); + } + + // Else: fold into the Quantize after the binop + auto qOut = state.quantOutputOfBinOp; + if (!qOut) + return rewriter.notifyMatchFailure( + op, "expected a unique Quantize user of the binary result"); + + state.dstNode = qOut.getOperation(); // Q + state.srcNode = state.dequantActivationOfBinOp.getOperation(); // DQ + + auto scaleOpt = getScalarTensorValueFromVal(qOut.getYScale()); + auto zpOpt = getScalarTensorValueFromVal(qOut.getYZeroPoint()); + if (!scaleOpt || !zpOpt) + return rewriter.notifyMatchFailure( + qOut, "Quantize y_scale/y_zero_point must be scalar"); + state.dstScale = *scaleOpt; + state.dstZeroPoint = *zpOpt; + return success(); + } + public: LogicalResult matchAndRewrite( BinOp op, PatternRewriter &rewriter) const override { @@ -478,8 +645,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { if (!op->hasOneUse()) { return rewriter.notifyMatchFailure(op, "pattern requires a single user"); } - auto quantOutputOp = dyn_cast(*op->user_begin()); - if (!quantOutputOp) { + auto quantOutputOfBinOp = dyn_cast(*op->user_begin()); + if (!quantOutputOfBinOp) { return rewriter.notifyMatchFailure( op, "expected user to be an ONNXQuantizeLinearOp"); } @@ -495,58 +662,85 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { "has const scalar value "); } - // Store the value of the scale and zero point of the destination node - { - Value scaleVal = state.dequantActivationOp.getXScale(); - Value zpVal = state.dequantActivationOp.getXZeroPoint(); - auto scale_value_opt = getScalarTensorValueFromVal(scaleVal); - auto zp_value_opt = getScalarTensorValueFromVal(zpVal); - if (!scale_value_opt || !zp_value_opt) { - return rewriter.notifyMatchFailure(state.dequantActivationOp, - " must be a scalar value or a list of same value"); - } - state.dstScale = scale_value_opt.value(); - state.dstZeroPoint = zp_value_opt.value(); + // STEP 3 + if (failed(findDestinationNode(rewriter, state, op))) { + return failure(); } - // STEP 3 + // STEP 4 if (failed(check_needed_values(rewriter, state, op))) { return failure(); } - // STEP 4 -Modify + // STEP 5 -Modify if (!compute_new_scale_and_zp_values(state)) { return failure(); } - // STEP 5: call initializer based on the binary op - ONNXDequantizeLinearOp dqAct = state.dequantActivationOp; - if constexpr (std::is_same_v || - std::is_same_v) { - Value zpVal = dqAct.getXZeroPoint(); - updateInitializer(rewriter, dqAct.getOperation(), zpVal, - static_cast(state.newZp)); - - } else if constexpr (std::is_same_v || - std::is_same_v) { - Value scaleVal = dqAct.getXScale(); - updateInitializer( - rewriter, dqAct.getOperation(), scaleVal, state.newScale); + // STEP 6: call initializer based on the binary op + { + auto *dst = state.dstNode; + if (!dst) + return rewriter.notifyMatchFailure(op, "dstNode not set"); + + if (auto dqDst = llvm::dyn_cast(dst)) { + if constexpr (std::is_same_v || + std::is_same_v) { + // Update zero-point at DQ.x + Value xZp = dqDst.getXZeroPoint(); + updateInitializer(rewriter, dqDst.getOperation(), xZp, + static_cast(state.newZp)); + } else if constexpr (std::is_same_v || + std::is_same_v) { + // Update scale at DQ.x + Value xScale = dqDst.getXScale(); + updateInitializer( + rewriter, dqDst.getOperation(), xScale, state.newScale); + } + } else if (auto qDst = llvm::dyn_cast(dst)) { + if constexpr (std::is_same_v || + std::is_same_v) { + // Update zero-point at Q.y + Value yZp = qDst.getYZeroPoint(); + updateInitializer(rewriter, qDst.getOperation(), yZp, + static_cast(state.newZp)); + } else if constexpr (std::is_same_v || + std::is_same_v) { + // Update scale at Q.y + Value yScale = qDst.getYScale(); + updateInitializer( + rewriter, qDst.getOperation(), yScale, state.newScale); + } + } else { + return rewriter.notifyMatchFailure( + op, "dstNode is neither Dequantize nor Quantize"); + } } // STEP 6: Remove binary op - rewriter.replaceOp(op, dqAct.getResult()); + rewriter.replaceOp(op, state.dequantActivationOfBinOp.getResult()); // STEP 7: Remove Q->DQ chain - - // prevent iterating and removing elements - auto users = llvm::make_early_inc_range(quantOutputOp.getY().getUsers()); - for (Operation *user : users) { - if (auto tailDQ = llvm::dyn_cast(user)) { - (void)tryRemoveQThenDQChain(rewriter, tailDQ); + ONNXQuantizeLinearOp chainStartQ = nullptr; + + if (llvm::isa(state.dstNode)) { + // Folding happened in the Dequantize: chain start is the Quantize after + // the BinOp + chainStartQ = state.quantOutputOfBinOp; // set earlier during match + } else if (llvm::isa(state.dstNode)) { + // Folding happened in the Quantize: chain start is the Quantize feeding + // DQ.x + if (auto dqAct = state.dequantActivationOfBinOp) { + chainStartQ = + dqAct.getX().template getDefiningOp(); } } + // Run the cleanup if we found a Quantize + if (chainStartQ) { + (void)Remove_Q_Plus_DQ(rewriter, chainStartQ, /*doRewrite=*/true); + } + return success(); } }; diff --git a/test/mlir/onnx/onnx_remove_add.mlir b/test/mlir/onnx/onnx_remove_add.mlir deleted file mode 100644 index b1baf41bb7..0000000000 --- a/test/mlir/onnx/onnx_remove_add.mlir +++ /dev/null @@ -1,255 +0,0 @@ -// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s --split-input-file | FileCheck %s - -// 1) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1a -// CHECK: %[[ZP:.*]] = onnx.Constant dense<336> : tensor -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<39664> : tensor -%3 = onnx.Constant dense<40000> : tensor -%4 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} - -// ----- -// 2) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1b -// CHECK: %[[ZP:.*]] = onnx.Constant dense<336> : tensor -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<39664> : tensor -%3 = onnx.Constant dense<40000> : tensor -%4 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} - -// ----- -// 3) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1c -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Add"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} - -// ----- -// 4) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1d -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Add"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} -//----- -// 5) dq1-const-add-q-dq. remove->add, q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2a -// CHECK: %[[ZP:.*]] = onnx.Constant dense<99> : tensor -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<100> : tensor -%1 = onnx.Constant dense<1.000000e+01> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<1.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 6) const-dq1-add-q-dq. remove->add,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2b -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 7) const-dq1-add-q-dq. kval=0. remove->add,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3a -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 8) const-dq1-Sub-q-dq. dst_scale=0. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3b -// CHECK: onnx.Add -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Add"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 9) dq1-dq2(const input)-add-q-dq. remove->add,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern4 -// CHECK-NOT: onnx.Add -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Add"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %4, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %9 : tensor<1x1x1x128xf32> -} -//----- -// 10) const-dq1-Sub-tanh. remove->none -// CHECK-LABEL: func.func @test_removebinary_pattern5 -// CHECK: onnx.Add -// CHECK: onnx.Tanh -func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Add"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%8 = "onnx.Tanh"(%7) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 11) dq1-dq2-add-q-dq1-dq2-mul-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->mul, add -// CHECK-LABEL: func.func @test_removebinary_pattern6 -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.Add -func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<39664> : tensor -%5 = onnx.Constant dense<2.57987776E-5> : tensor -%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Div"(%6, %7) {onnx_node_name = "/bert/Sub"} : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%12 = "onnx.Add"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %14 : tensor<1x1x1x128xf32> -} -//----- -// -// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// vectors wiht same values -> fusion -// CHECK-LABEL: func.func @test_removebinary_pattern7a -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<65535> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Add"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} -//----- -// -// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// vectors wiht different values -> no fusion -// CHECK-LABEL: func.func @test_removebinary_pattern7b -// CHECK: onnx.Add -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Add"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} diff --git a/test/mlir/onnx/onnx_remove_div.mlir b/test/mlir/onnx/onnx_remove_div.mlir deleted file mode 100644 index d73b65846f..0000000000 --- a/test/mlir/onnx/onnx_remove_div.mlir +++ /dev/null @@ -1,249 +0,0 @@ -// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s -split-input-file | FileCheck %s - -// 1) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1a -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.52590218E-9> : tensor -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 2) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1b -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.52590218E-9> : tensor -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 3) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1c -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Div"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} -//----- -// 4) dq1-dq2(const input)-div-q-dq. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1d -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Div"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} -//----- -// 5) dq1-const-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2a -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-1.49999991E-7> : tensor -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.500000e-05> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+02> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 6) const-dq1-div-q-dq. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2b -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 7) const-dq1-div-q-dq. kval=0. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3c -// CHECK: onnx.Div -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern3c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 8) const-dq1-div-q-dq. dst_scale=0. remove->div,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3b -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 9) const-dq1-div-q-dq. q!=dq. remove->only div -// CHECK-LABEL: func.func @test_removebinary_pattern4 -// CHECK-NOT: onnx.Div -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %1, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 10) const-dq1-div-tanh. remove->none -// CHECK-LABEL: func.func @test_removebinary_pattern5 -// CHECK: onnx.Div -// CHECK: onnx.Tanh -func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Div"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.Tanh"(%6) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -return %7 : tensor<1x1x1x128xf32> -} -//----- -// 11) dq1-dq2-sub-q-dq1-dq2-div-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->div, sub -// CHECK-LABEL: func.func @test_removebinary_pattern6 -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.Sub -func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<39664> : tensor -%5 = onnx.Constant dense<2.57987776E-5> : tensor -%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Sub"(%6, %7) {onnx_node_name = "/bert/Sub"} : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%12 = "onnx.Div"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %14 : tensor<1x1x1x128xf32> -} -//----- -// -// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-div-q-dq. -// Keep Div and QuantizeLinear present. -// CHECK-LABEL: func.func @test_removebinary_pattern7a -// CHECK-NOT: onnx.Div -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<65535> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Div"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} -//----- -// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-div-q-dq. -// Keep Div and QuantizeLinear present. -// CHECK-LABEL: func.func @test_removebinary_pattern7b -// CHECK: onnx.Div -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Div"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} diff --git a/test/mlir/onnx/onnx_remove_mul.mlir b/test/mlir/onnx/onnx_remove_mul.mlir deleted file mode 100644 index ad21131746..0000000000 --- a/test/mlir/onnx/onnx_remove_mul.mlir +++ /dev/null @@ -1,254 +0,0 @@ -// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s -split-input-file | FileCheck %s - -// 1) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1a -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.152590215> : tensor -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- - -// 2) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1b -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.152590215> : tensor -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// Test fusing a DQ -> Mul -> Q -> DQ pattern. -// The constant input to the Mul is produced by a chain: Constant -> Identity -> DequantizeLinear. -// The pass should look through the Identity op, perform the fusion, and remove the redundant Q->DQ chain. -// 3) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1c -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Mul"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} -//----- -// 4) dq1-dq2(const input)-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1d -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Mul"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} -//----- -// 5) dq1-const-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2a -// CHECK: %[[SCALE:.*]] = onnx.Constant dense<-0.00152590219> : tensor -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+02> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 6) const-dq1-mul-q-dq. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2b -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 7) const-dq1-mul-q-dq. kval=0. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3a -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 8) const-dq1-mul-q-dq. dst_scale=0. remove->mul,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3b -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 9) const-dq1-mul-q-dq. q!=dq. remove->only mul -// CHECK-LABEL: func.func @test_removebinary_pattern4 -// CHECK-NOT: onnx.Mul -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %1, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 10) const-dq1-mul-tanh. remove->none -// CHECK-LABEL: func.func @test_removebinary_pattern5 -// CHECK: onnx.Mul -// CHECK: onnx.Tanh -func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Mul"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.Tanh"(%6) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -return %7 : tensor<1x1x1x128xf32> -} -//----- -// 11) dq1-dq2-sub-q-dq1-dq2-mul-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->mul, sub -// CHECK-LABEL: func.func @test_removebinary_pattern6 -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.Sub -func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<39664> : tensor -%5 = onnx.Constant dense<2.57987776E-5> : tensor -%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Add"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%12 = "onnx.Mul"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %14 : tensor<1x1x1x128xf32> -} -//----- -// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// Keep Mul and QuantizeLinear present. -// CHECK-LABEL: func.func @test_removebinary_pattern7a -// CHECK-NOT: onnx.Mul -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<65535> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Mul"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} -//----- -// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// Keep Mul and QuantizeLinear present. -// CHECK-LABEL: func.func @test_removebinary_pattern7b -// CHECK: onnx.Mul -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Mul"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} diff --git a/test/mlir/onnx/onnx_remove_sub.mlir b/test/mlir/onnx/onnx_remove_sub.mlir deleted file mode 100644 index e92b1907cd..0000000000 --- a/test/mlir/onnx/onnx_remove_sub.mlir +++ /dev/null @@ -1,257 +0,0 @@ -// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s --split-input-file | FileCheck %s - -// 1) dq1-dq2(const input)-sub-q-dq. remove->sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1a -// CHECK: %[[ZP:.*]] = onnx.Constant dense<65535> : tensor -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Sub"(%6, %5) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %9 : tensor<1x1x1x128xf32> -} - -// ----- -// 2) dq1-dq2(const input)-sub-q-dq. remove->sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1b -// CHECK: %[[ZP:.*]] = onnx.Constant dense<65535> : tensor -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern1b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %9 : tensor<1x1x1x128xf32> -} - -// ----- -// 3) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1c -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1c(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Sub"(%7, %6) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} - -// ----- -// 4) dq1-dq2(const input)-Sub-q-dq. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern1d -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern1d(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0> : tensor -%5 = "onnx.Identity"(%4) : (tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%5, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Sub"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %10 : tensor<1x1x1x128xf32> -} - -//----- -// 5) dq1-const-add-q-dq. remove->add, q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2a -// CHECK: %[[ZP:.*]] = onnx.Constant dense<102> : tensor -// CHECK-NOT: onnx.Add -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<101> : tensor -%1 = onnx.Constant dense<1.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<1.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Sub"(%5, %4) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 6) const-dq1-sub-q-dq. remove->sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern2b -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern2b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 7) const-dq1-sub-q-dq. kval=0. remove->sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3a -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -// CHECK: return -// CHECK-NOT: onnx.DequantizeLinear -func.func @test_removebinary_pattern3a(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<0.000000e+00> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 8) const-dq1-Sub-q-dq. dst_scale=0. remove->Sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern3b -// CHECK: onnx.Sub -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern3b(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<0.000000e+00> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<-1.000000e+04> : tensor -%5 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%6 = "onnx.Sub"(%4, %5) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%7 = "onnx.QuantizeLinear"(%6, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%8 = "onnx.DequantizeLinear"(%7, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 9) dq1-dq2(const input)-sub-q-dq. remove->sub,q-dq. -// CHECK-LABEL: func.func @test_removebinary_pattern4 -// CHECK-NOT: onnx.Sub -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern4(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %4, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %9 : tensor<1x1x1x128xf32> -} -//----- -// 10) const-dq1-Sub-tanh. remove->none -// CHECK-LABEL: func.func @test_removebinary_pattern5 -// CHECK: onnx.Sub -// CHECK: onnx.Tanh -func.func @test_removebinary_pattern5(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<39664> : tensor -%4 = onnx.Constant dense<2.57987776E-5> : tensor -%5 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%7 = "onnx.Sub"(%5, %6) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%8 = "onnx.Tanh"(%7) : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -return %8 : tensor<1x1x1x128xf32> -} -//----- -// 11) dq1-dq2-sub-q-dq1-dq2-mul-Q-DQ. multi-use of scale and zp of dq-act before binary op. remove->mul, sub -// CHECK-LABEL: func.func @test_removebinary_pattern6 -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.Sub -func.func @test_removebinary_pattern6(%arg0: tensor<1x1x1x128xui16>) -> tensor<1x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor -%1 = onnx.Constant dense<1.52590219E-5> : tensor -%2 = onnx.Constant dense<65535> : tensor -%3 = onnx.Constant dense<0.152590215> : tensor -%4 = onnx.Constant dense<39664> : tensor -%5 = onnx.Constant dense<2.57987776E-5> : tensor -%6 = "onnx.DequantizeLinear"(%2, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%7 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%8 = "onnx.Mul"(%6, %7) : (tensor, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> -%9 = "onnx.QuantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%10 = "onnx.DequantizeLinear"(%9, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -%11 = "onnx.DequantizeLinear"(%0, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor -%12 = "onnx.Sub"(%10, %11) : (tensor<1x1x1x128xf32>, tensor) -> tensor<1x1x1x128xf32> -%13 = "onnx.QuantizeLinear"(%12, %3, %2) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x1x128xf32>, tensor, tensor) -> tensor<1x1x1x128xui16> -%14 = "onnx.DequantizeLinear"(%13, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x1x128xui16>, tensor, tensor) -> tensor<1x1x1x128xf32> -return %14 : tensor<1x1x1x128xf32> -} -//----- -// 12) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// vectors wiht same values -> fusion -// CHECK-LABEL: func.func @test_removebinary_pattern7a -// CHECK-NOT: onnx.Sub -// CHECK-NOT: onnx.QuantizeLinear -func.func @test_removebinary_pattern7a(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<65535> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Sub"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} -//----- -// -// 13) dq1-dq2(const input, per-axis length-2 on axis=0)-mul-q-dq. -// vectors wiht different values -> no fusion -// CHECK-LABEL: func.func @test_removebinary_pattern7b -// CHECK: onnx.Sub -// CHECK: onnx.QuantizeLinear -func.func @test_removebinary_pattern7b(%arg0: tensor<2x1x1x128xui16>) -> tensor<2x1x1x128xf32> { -%0 = onnx.Constant dense<0> : tensor<2xui16> -%1 = onnx.Constant dense<1.52590219E-5> : tensor<2xf32> -%2 = onnx.Constant dense<[65535, 1]> : tensor<2xui16> -%3 = onnx.Constant dense<0.152590215> : tensor<2xf32> -%4 = onnx.Constant dense<0> : tensor<2x1x1x1xui16> -%5 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x1xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x1xf32> -%6 = "onnx.DequantizeLinear"(%arg0, %1, %0) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -%7 = "onnx.Sub"(%6, %5) : (tensor<2x1x1x128xf32>, tensor<2x1x1x1xf32>) -> tensor<2x1x1x128xf32> -%8 = "onnx.QuantizeLinear"(%7, %3, %2) {axis = 0 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x1x1x128xf32>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xui16> -%9 = "onnx.DequantizeLinear"(%8, %3, %2) {axis = 0 : si64, block_size = 0 : si64} : (tensor<2x1x1x128xui16>, tensor<2xf32>, tensor<2xui16>) -> tensor<2x1x1x128xf32> -return %9 : tensor<2x1x1x128xf32> -} From 963ea04bdb0e328ec72883badc58b955796016ed Mon Sep 17 00:00:00 2001 From: sushmita Date: Tue, 11 Nov 2025 00:17:29 +0530 Subject: [PATCH 2/5] cleaned and test added --- src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp | 132 +++---- test/mlir/onnx/onnx_remove_binary.mlir | 381 +++++++++++++++++++ 2 files changed, 434 insertions(+), 79 deletions(-) create mode 100644 test/mlir/onnx/onnx_remove_binary.mlir diff --git a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp index 8c3f562612..9a4d57fcf5 100644 --- a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp +++ b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp @@ -4,6 +4,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -17,7 +18,6 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Pass/Passes.hpp" -#include "llvm/ADT/STLExtras.h" #include #include #include @@ -108,8 +108,8 @@ std::optional getScalarTensorValueFromVal(Value value) { return getScalarTensorValue(constOp); } -static mlir::DenseElementsAttr makeScalarDEA( - mlir::ShapedType likeTy, double d) { +static mlir::DenseElementsAttr makeScalarDEA(mlir::ShapedType likeTy, + double d) { using namespace mlir; auto ranked = likeTy.dyn_cast(); @@ -127,7 +127,7 @@ static mlir::DenseElementsAttr makeScalarDEA( llvm::APFloat ap(d); bool loses = false; ap.convert(useFT.getFloatSemantics(), llvm::APFloat::rmNearestTiesToEven, - &loses); + &loses); dv = ap.convertToDouble(); } return DenseElementsAttr::get(ranked, FloatAttr::get(outFT, dv)); @@ -167,7 +167,8 @@ static mlir::DenseElementsAttr makeScalarDEA( } static void updateInitializer(mlir::PatternRewriter &rewriter, - mlir::Operation *targetOp, mlir::Value oldInit, double newScalar) { + mlir::Operation *targetOp, mlir::Value oldInit, + double newScalar) { using namespace mlir; if (!targetOp || !oldInit) @@ -236,9 +237,9 @@ static void updateInitializer(mlir::PatternRewriter &rewriter, // Returns success() iff Q->DQ is *removable* under strict checks. // If doRewrite==true, it also *applies* the rewrite for this DQ (replaces DQ // with Q.x). -static mlir::LogicalResult tryRemoveQThenDQChain( - mlir::PatternRewriter &rewriter, mlir::ONNXDequantizeLinearOp dqOp, - bool doRewrite) { +static mlir::LogicalResult +tryRemoveQThenDQChain(mlir::PatternRewriter &rewriter, + mlir::ONNXDequantizeLinearOp dqOp, bool doRewrite) { using namespace mlir; // Match direct Q -> DQ @@ -286,8 +287,8 @@ static mlir::LogicalResult tryRemoveQThenDQChain( // If doRewrite=false: returns true iff *any* removable DQ user exists (no // mutation). If doRewrite=true : performs removals and returns true iff it // removed at least one DQ. Also erases Q if it becomes dead after removals. -static bool Remove_Q_Plus_DQ( - mlir::PatternRewriter &rewriter, ONNXQuantizeLinearOp qOp, bool doRewrite) { +static bool Remove_Q_Plus_DQ(mlir::PatternRewriter &rewriter, + ONNXQuantizeLinearOp qOp, bool doRewrite) { using namespace mlir; if (!qOp) return false; @@ -322,7 +323,7 @@ static bool isValuePreservingOp(mlir::Operation *op) { if (!op) return false; return isa(op); + mlir::ONNXUnsqueezeOp, mlir::ONNXTransposeOp>(op); } template @@ -353,7 +354,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { }; LogicalResult match_qdq(mlir::PatternRewriter &rewriter, MatchState &state, - ONNXDequantizeLinearOp dq1, ONNXDequantizeLinearOp dq2) const { + ONNXDequantizeLinearOp dq1, + ONNXDequantizeLinearOp dq2) const { ONNXDequantizeLinearOp constantDqOp = nullptr; ONNXConstantOp constantSourceOp = nullptr; @@ -374,7 +376,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { if (auto intermediateOp = dq1.getX().getDefiningOp()) { if (isValuePreservingOp(intermediateOp)) { if (auto constOp = intermediateOp->getOperand(0) - .getDefiningOp()) { + .getDefiningOp()) { constantDqOp = dq1; state.dequantActivationOfBinOp = dq2; constantSourceOp = constOp; @@ -384,7 +386,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { if (auto intermediateOp = dq2.getX().getDefiningOp()) { if (isValuePreservingOp(intermediateOp)) { if (auto constOp = intermediateOp->getOperand(0) - .getDefiningOp()) { + .getDefiningOp()) { constantDqOp = dq2; state.dequantActivationOfBinOp = dq1; constantSourceOp = constOp; @@ -401,7 +403,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { { auto scalar_value_opt = getScalarTensorValue(constantSourceOp); if (!scalar_value_opt) { - return rewriter.notifyMatchFailure(constantSourceOp, + return rewriter.notifyMatchFailure( + constantSourceOp, " must be a scalar value or a list of same value"); } Value scaleVal = constantDqOp.getXScale(); @@ -419,7 +422,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } LogicalResult match_binary_op(mlir::PatternRewriter &rewriter, - MatchState &state, BinOp binaryOp) const { + MatchState &state, BinOp binaryOp) const { ONNXConstantOp constantOp = nullptr; Value lhs = binaryOp.getOperand(0); @@ -427,11 +430,6 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { Value out = binaryOp->getResult(0); state.quantOutputOfBinOp = dyn_cast(*out.getUsers().begin()); - // auto qOut = getUniqueQuantUserOrNull(binaryOp->getResult(0)); - // if (!qOut) - // return rewriter.notifyMatchFailure(binaryOp, - // "binary result must have exactly one ONNXQuantizeLinearOp user"); - // state.quantOutputOfBinOp = qOut; // -------- Case A: lhs is DQ, rhs is Constant -------- if (auto dqOp = lhs.getDefiningOp()) { @@ -469,28 +467,9 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { return failure(); } - /*LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, - const MatchState &state, Operation *binaryOp) const { - if (state.kValue == 0.0) { - if (isa(binaryOp)) { - return rewriter.notifyMatchFailure(binaryOp, - "when opType is Div, remove binary op only if k_value is " - "not zero, to avoid ZeroDivisionError"); - } - } - if (state.dstScale == 0.0) { - if (isa(binaryOp)) { - return rewriter.notifyMatchFailure(binaryOp, - "when opType is Add or Sub, remove binary op only if y_scale is " - "not " - "zero, to avoid ZeroDivisionError"); - } - } - return success(); - }*/ - LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, - const MatchState &state, Operation *binaryOp) const { + const MatchState &state, + Operation *binaryOp) const { const bool dstIsDQ = llvm::isa(state.dstNode); const bool dstIsQ = llvm::isa(state.dstNode); @@ -499,12 +478,15 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // 0 in those cases. if (state.kValue == 0.0) { if (dstIsDQ && llvm::isa(binaryOp)) { - return rewriter.notifyMatchFailure(binaryOp, + + return rewriter.notifyMatchFailure( + binaryOp, "when opType is Div, remove binary op only if k_value is not zero, " "to avoid ZeroDivisionError"); - } - else if (dstIsQ && llvm::isa(binaryOp)) { - return rewriter.notifyMatchFailure(binaryOp, + } else if (dstIsQ && llvm::isa(binaryOp)) { + + return rewriter.notifyMatchFailure( + binaryOp, "when opType is Mul, remove binary op only if k_value is not zero, " "to avoid ZeroDivisionError"); } @@ -513,7 +495,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // k/scale is used for Add/Sub to update zero_point. // Avoid division by zero when dstScale == 0. if (state.dstScale == 0.0 && (llvm::isa(binaryOp))) { - return rewriter.notifyMatchFailure(binaryOp, + return rewriter.notifyMatchFailure( + binaryOp, "when opType is Add or Sub, remove binary op only if scale is not " "zero, to avoid ZeroDivisionError"); } @@ -552,7 +535,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { std::is_same_v || std::is_same_v || std::is_same_v, - "Unsupported binary operation type for this pattern"); + "Unsupported binary operation type for this pattern"); return false; } @@ -564,20 +547,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { return true; } - static ONNXQuantizeLinearOp getSingleQuantizeUser(Value v) { - ONNXQuantizeLinearOp q = nullptr; - for (Operation *u : v.getUsers()) { - if (auto cand = dyn_cast(u)) { - if (q) - return nullptr; // more than one Quantize user - q = cand; - } - } - return q; - } - - LogicalResult findDestinationNode( - mlir::PatternRewriter &rewriter, MatchState &state, Operation *op) const { + LogicalResult findDestinationNode(mlir::PatternRewriter &rewriter, + MatchState &state, Operation *op) const { auto dq = state.dequantActivationOfBinOp; if (!dq) return rewriter.notifyMatchFailure( @@ -638,8 +609,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } public: - LogicalResult matchAndRewrite( - BinOp op, PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(BinOp op, + PatternRewriter &rewriter) const override { // STEP 1: Match begin: Assuming only one user if (!op->hasOneUse()) { @@ -656,10 +627,10 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // STEP 2 if (failed(match_binary_op(rewriter, state, op))) { - return rewriter.notifyMatchFailure(op, - " does not match to critieria to remove binary. Remove binary op " - "only if one of the dequantize linear input " - "has const scalar value "); + return rewriter.notifyMatchFailure( + op, " does not match to critieria to remove binary. Remove binary op " + "only if one of the dequantize linear input " + "has const scalar value "); } // STEP 3 @@ -669,6 +640,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // STEP 4 if (failed(check_needed_values(rewriter, state, op))) { + return failure(); } @@ -689,13 +661,13 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // Update zero-point at DQ.x Value xZp = dqDst.getXZeroPoint(); updateInitializer(rewriter, dqDst.getOperation(), xZp, - static_cast(state.newZp)); + static_cast(state.newZp)); } else if constexpr (std::is_same_v || std::is_same_v) { // Update scale at DQ.x Value xScale = dqDst.getXScale(); - updateInitializer( - rewriter, dqDst.getOperation(), xScale, state.newScale); + updateInitializer(rewriter, dqDst.getOperation(), xScale, + state.newScale); } } else if (auto qDst = llvm::dyn_cast(dst)) { if constexpr (std::is_same_v || @@ -703,13 +675,13 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // Update zero-point at Q.y Value yZp = qDst.getYZeroPoint(); updateInitializer(rewriter, qDst.getOperation(), yZp, - static_cast(state.newZp)); + static_cast(state.newZp)); } else if constexpr (std::is_same_v || std::is_same_v) { // Update scale at Q.y Value yScale = qDst.getYScale(); - updateInitializer( - rewriter, qDst.getOperation(), yScale, state.newScale); + updateInitializer(rewriter, qDst.getOperation(), yScale, + state.newScale); } } else { return rewriter.notifyMatchFailure( @@ -717,10 +689,10 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } } - // STEP 6: Remove binary op + // STEP 7: Remove binary op rewriter.replaceOp(op, state.dequantActivationOfBinOp.getResult()); - // STEP 7: Remove Q->DQ chain + // STEP 8: Remove Q->DQ chain ONNXQuantizeLinearOp chainStartQ = nullptr; if (llvm::isa(state.dstNode)) { @@ -749,7 +721,9 @@ struct FoldDQBinaryQPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldDQBinaryQPass) - StringRef getArgument() const final { return "dq-binary-q-opt-onnx-to-onnx"; } + StringRef getArgument() const final { + return "dq-binary-q-opt-onnx-to-onnx"; + } StringRef getDescription() const final { return "Fold Add/Sub/Mul/Div through Q/DQ by updating scale/zero_point, " "then remove trivial Q->DQ chains when safe."; @@ -760,7 +734,7 @@ struct FoldDQBinaryQPass RewritePatternSet patterns(&getContext()); patterns .add, FoldBinaryThroughQDQ, - FoldBinaryThroughQDQ, FoldBinaryThroughQDQ>( + FoldBinaryThroughQDQ, FoldBinaryThroughQDQ>( &getContext()); if (failed(applyPatternsGreedily(function, std::move(patterns)))) signalPassFailure(); @@ -772,4 +746,4 @@ namespace onnx_mlir { std::unique_ptr createFoldDQBinaryQPass() { return std::make_unique(); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/test/mlir/onnx/onnx_remove_binary.mlir b/test/mlir/onnx/onnx_remove_binary.mlir new file mode 100644 index 0000000000..740c557d7b --- /dev/null +++ b/test/mlir/onnx/onnx_remove_binary.mlir @@ -0,0 +1,381 @@ +// RUN: mlir-opt --pass-pipeline="func.func(your-pass-name)" %s | FileCheck %s + + func.func @test_fold_mul_case_b_safe(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { + %0 = onnx.Constant dense<0> : tensor + %1 = onnx.Constant dense<5.78499521E-6> : tensor + %2 = onnx.Constant dense<0> : tensor + %3 = onnx.Constant dense<0.00152590231> : tensor + %4 = onnx.Constant dense<65535> : tensor + %5 = onnx.Constant dense<10> : tensor + %6 = onnx.Constant dense<1.000000e-01> : tensor + %7 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor + %8 = "onnx.QuantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + %10 = "onnx.Mul"(%9, %7) : (tensor<10x1xf32>, tensor) -> tensor<10x1xf32> + %11 = "onnx.QuantizeLinear"(%10, %6, %5) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %12 = "onnx.DequantizeLinear"(%11, %6, %5) {axis = 1 : si64, block_size = 0 : si64} : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + return %12 : tensor<10x1xf32> + } + +// CHECK: %[[ZP:.*]] = onnx.Constant dense<10> : tensor +// CHECK: %[[DQ_SCALE:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[NEW_SCALE:.*]] = onnx.Constant dense<9.99999931E-4> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[NEW_SCALE]], %[[ZP]]) +// CHECK-SAME: : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> +// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[DQ_SCALE]], %[[ZP]]) +// CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> +// CHECK: return %[[DQ]] +// CHECK-NOT: "onnx.Mul" + +// ============================================================================ +// ===== CASE A: lhs = DQ, rhs = Const (fold into Q; update Q.y_zero_point) ===== +// ============================================================================ + +func.func @caseA_lhsDQ_rhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { + %0 = onnx.Constant dense<5.000000e-01> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %3 = onnx.Constant dense<5.000000e-01> : tensor + %4 = onnx.Constant dense<0> : tensor + %5 = "onnx.DequantizeLinear"(%2, %3, %4) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + %6 = onnx.Constant dense<1.000000e+01> : tensor + %7 = "onnx.Add"(%5, %6) : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %8 = onnx.Constant dense<1.000000e-01> : tensor + %9 = onnx.Constant dense<0> : tensor + %10 = "onnx.QuantizeLinear"(%7, %8, %9) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + return %10 : tensor<1x4xi8> + } + +// CHECK-LABEL: func.func @caseA_lhsDQ_rhsConst_foldIntoQ +// CHECK: %[[S:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<99> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> +// CHECK: return %[[Q]] : tensor<1x4xi8> + +// ============================================================================ +// ===== CASE A-REV: rhs = DQ, lhs = Const (fold into Q; update Q.y_zero_point) ===== +// ============================================================================ + +func.func @caseA_rev_rhsDQ_lhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { + %0 = onnx.Constant dense<5.000000e-01> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %3 = onnx.Constant dense<5.000000e-01> : tensor + %4 = onnx.Constant dense<0> : tensor + %5 = "onnx.DequantizeLinear"(%2, %3, %4) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + %6 = onnx.Constant dense<1.000000e+01> : tensor + %7 = "onnx.Add"(%6, %5) : (tensor, tensor<1x4xf32>) -> tensor<1x4xf32> + %8 = onnx.Constant dense<1.000000e-01> : tensor + %9 = onnx.Constant dense<0> : tensor + %10 = "onnx.QuantizeLinear"(%7, %8, %9) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + return %10 : tensor<1x4xi8> + } + +// CHECK-LABEL: func.func @caseA_rev_rhsDQ_lhsConst_foldIntoQ +// CHECK: %[[S:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<99> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> +// CHECK: return %[[Q]] : tensor<1x4xi8> + +// ============================================================================ +// ===== CASE B: both inputs are DQ; constant via dq1 (fold into Q) ===== +// ============================================================================ + +func.func @caseB_bothDQ_constViaDQ1_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { + %0 = onnx.Constant dense<5.000000e-01> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %3 = "onnx.DequantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + %4 = onnx.Constant dense<10> : tensor + %5 = onnx.Constant dense<5.000000e+00> : tensor + %6 = onnx.Constant dense<0> : tensor + %7 = "onnx.DequantizeLinear"(%4, %5, %6) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor + %8 = "onnx.Add"(%3, %7) : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %9 = onnx.Constant dense<5.000000e-01> : tensor + %10 = onnx.Constant dense<0> : tensor + %11 = "onnx.QuantizeLinear"(%8, %9, %10) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + return %11 : tensor<1x4xi8> + } +// CHECK-LABEL: func.func @caseB_bothDQ_constViaDQ1_foldIntoQ +// CHECK: %[[S:.*]] = onnx.Constant dense<5.000000e-01> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<100> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> +// CHECK: return %[[Q]] : tensor<1x4xi8> + +// ============================================================================ +// ===== CASE B with value-preserving link on constant side: Reshape(const_q) → DQ ===== +// ============================================================================ + + func.func @caseB_constViaReshape_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { + %0 = onnx.Constant dense<1.000000e+00> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %3 = "onnx.DequantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + %4 = onnx.Constant dense<25> : tensor + %5 = onnx.Constant dense<> : tensor<0xi64> + %6 = "onnx.Reshape"(%4, %5) {allowzero = 0 : si64} : (tensor, tensor<0xi64>) -> tensor + %7 = onnx.Constant dense<4.000000e+00> : tensor + %8 = onnx.Constant dense<0> : tensor + %9 = "onnx.DequantizeLinear"(%6, %7, %8) {axis = 1 : si64, block_size = 0 : si64} : (tensor, tensor, tensor) -> tensor + %10 = "onnx.Add"(%3, %9) : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %11 = onnx.Constant dense<1.000000e+00> : tensor + %12 = onnx.Constant dense<0> : tensor + %13 = "onnx.QuantizeLinear"(%10, %11, %12) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + return %13 : tensor<1x4xi8> + } + +// CHECK-LABEL: func.func @caseB_constViaReshape_foldIntoQ +// CHECK: %[[SCALE:.*]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<100> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[SCALE]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> +// CHECK: return %[[Q]] : tensor<1x4xi8> +// CHECK-NOT: onnx.Add +// CHECK-NOT: onnx.DequantizeLinear +// CHECK-NOT: onnx.Reshape + +// ============================================================================ +// ===== BRANCH-BEFORE: Q1 has another user (fold into DQ; update DQ.x_zero_point) ===== +// ============================================================================ + + func.func @branchBefore_foldIntoDQ(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi8>) { + %0 = onnx.Constant dense<5.000000e-01> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %3 = "onnx.Abs"(%2) : (tensor<1x4xi8>) -> tensor<1x4xi8> + %4 = onnx.Constant dense<1.000000e-01> : tensor + %5 = onnx.Constant dense<0> : tensor + %6 = "onnx.DequantizeLinear"(%2, %4, %5) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + %7 = onnx.Constant dense<1.000000e+01> : tensor + %8 = "onnx.Add"(%6, %7) : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %9 = onnx.Constant dense<2.000000e-01> : tensor + %10 = onnx.Constant dense<0> : tensor + %11 = "onnx.QuantizeLinear"(%8, %9, %10) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %12 = "onnx.DequantizeLinear"(%11, %9, %10) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + return %12, %3 : tensor<1x4xf32>, tensor<1x4xi8> + } + +// CHECK-LABEL: func.func @branchBefore_foldIntoDQ +// CHECK: %[[S_DQ:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[S_Q:.*]] = onnx.Constant dense<5.000000e-01> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<0> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S_Q]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> +// CHECK: %[[ABS:.*]] = "onnx.Abs"(%[[Q]]) +// CHECK-SAME: : (tensor<1x4xi8>) -> tensor<1x4xi8> +// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S_DQ]], %[[ZP]]) +// CHECK-SAME: : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> +// CHECK: return %[[DQ]], %[[ABS]] : tensor<1x4xf32>, tensor<1x4xi8> + + +// ============================================================================ +// k_value == 0 and (dst is DequantizeLinear) with a Div +// Expectation: DO NOT fold (would require scale_new = scale / k, div-by-zero) +// Reason: k_value = (const_q - zp) * scale_const = (7 - 7) * 0.5 = 0 +// ============================================================================ + +func.func @guard_div_into_dq_k_zero(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> { + // Activation path: Q -> DQ + %s_act = onnx.Constant dense<5.000000e-01> : tensor + %zp_act = onnx.Constant dense<0> : tensor + %q_act = "onnx.QuantizeLinear"(%arg0, %s_act, %zp_act) : (tensor<1x4xf32>, tensor, tensor) -> tensor<1x4xi8> + %dq_act = "onnx.DequantizeLinear"(%q_act, %s_act, %zp_act) : (tensor<1x4xi8>, tensor, tensor) -> tensor<1x4xf32> + + // Constant path into DQ with k_value == 0 + // const_q = 7, zp = 7, scale = 0.5 => k = (7-7)*0.5 = 0 + %const_q = onnx.Constant dense<7> : tensor + %s_c = onnx.Constant dense<5.000000e-01> : tensor + %zp_c = onnx.Constant dense<7> : tensor + %dq_c = "onnx.DequantizeLinear"(%const_q, %s_c, %zp_c) : (tensor, tensor, tensor) -> tensor + + // Binary op is Div. Destination for a fold here would be the upstream DQ (%dq_act). + %div = "onnx.Div"(%dq_act, %dq_c) : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + + return %div : tensor<1x4xf32> +} + + // CHECK-LABEL: @guard_div_into_dq_k_zero + // CHECK: "onnx.Div"( + // (No folding → Div must remain present.) + +// ============================================================================ +// k_value == 0 and (dst is QuantizeLinear) with a Mul +// ============================================================================ + +func.func @test_kval_0_dst_q_mul(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { + %0 = onnx.Constant dense<0> : tensor + %1 = onnx.Constant dense<5.78499521E-6> : tensor + %2 = onnx.Constant dense<0> : tensor + %3 = onnx.Constant dense<0.00152590231> : tensor + %4 = onnx.Constant dense<0> : tensor + %5 = onnx.Constant dense<10> : tensor + %6 = onnx.Constant dense<1.000000e-01> : tensor + %7 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 1 : si64, block_size = 0 : si64} + : (tensor, tensor, tensor) -> tensor + %8 = "onnx.QuantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} + : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} + : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + %10 = "onnx.Mul"(%9, %7) : (tensor<10x1xf32>, tensor) -> tensor<10x1xf32> + %11 = "onnx.QuantizeLinear"(%10, %6, %5) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} + : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %12 = "onnx.DequantizeLinear"(%11, %6, %5) {axis = 1 : si64, block_size = 0 : si64} + : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + + return %12 : tensor<10x1xf32> +} + +// CHECK-LABEL: func.func @test_kval_0_dst_q_mul( +// CHECK-SAME: %arg0: tensor<10x1xf32>) -> tensor<10x1xf32> +// CHECK: %[[ZP0:.*]] = onnx.Constant dense<0> : tensor +// CHECK: %[[S_ACT:.*]] = onnx.Constant dense<5.78499521E-6> : tensor +// CHECK: %[[S_K:.*]] = onnx.Constant dense<0.00152590231> : tensor +// CHECK: %[[ZP_OUT:.*]] = onnx.Constant dense<10> : tensor +// CHECK: %[[S_OUT:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[DQK:.*]] = "onnx.DequantizeLinear"(%[[ZP0]], %[[S_K]], %[[ZP0]]) +// CHECK-SAME: : (tensor, tensor, tensor) -> tensor +// CHECK: %[[QACT:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S_ACT]], %[[ZP0]]) +// CHECK-SAME: : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> +// CHECK: %[[DQACT:.*]] = "onnx.DequantizeLinear"(%[[QACT]], %[[S_ACT]], %[[ZP0]]) +// CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> +// CHECK: %[[MUL:.*]] = "onnx.Mul"(%[[DQACT]], %[[DQK]]) +// CHECK-SAME: : (tensor<10x1xf32>, tensor) -> tensor<10x1xf32> +// CHECK: %[[QOUT:.*]] = "onnx.QuantizeLinear"(%[[MUL]], %[[S_OUT]], %[[ZP_OUT]]) +// CHECK-SAME: : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> +// CHECK: %[[DQOUT:.*]] = "onnx.DequantizeLinear"(%[[QOUT]], %[[S_OUT]], %[[ZP_OUT]]) +// CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> +// CHECK: return %[[DQOUT]] : tensor<10x1xf32> + +// ============================================================================ +// k_value == 0 and (dst is QuantizeLinear) with a Div +// ============================================================================ + +func.func @test_kval_0_dst_q_div(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { + %0 = onnx.Constant dense<0> : tensor + %1 = onnx.Constant dense<5.78499521E-6> : tensor + %2 = onnx.Constant dense<0> : tensor + %3 = onnx.Constant dense<0.00152590231> : tensor + %4 = onnx.Constant dense<0> : tensor + %5 = onnx.Constant dense<10> : tensor + %6 = onnx.Constant dense<1.000000e-01> : tensor + %7 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 1 : si64, block_size = 0 : si64} + : (tensor, tensor, tensor) -> tensor + %8 = "onnx.QuantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} + : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} + : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + %10 = "onnx.Div"(%9, %7) : (tensor<10x1xf32>, tensor) -> tensor<10x1xf32> + %11 = "onnx.QuantizeLinear"(%10, %6, %5) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} + : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> + %12 = "onnx.DequantizeLinear"(%11, %6, %5) {axis = 1 : si64, block_size = 0 : si64} + : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> + + return %12 : tensor<10x1xf32> +} + +// CHECK-LABEL: func.func @test_kval_0_dst_q_div( +// CHECK-SAME: %arg0: tensor<10x1xf32>) -> tensor<10x1xf32> +// CHECK: %[[ZP:.*]] = onnx.Constant dense<10> : tensor +// CHECK: %[[S_DQ:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[S_Q:.*]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S_Q]], %[[ZP]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} +// CHECK-SAME: : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> +// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S_DQ]], %[[ZP]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64} +// CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> +// CHECK-NOT: "onnx.Div" +// CHECK: return %[[DQ]] : tensor<10x1xf32> + +// ============================================================================ +// Test A: Fold happened into DQ → chainStartQ = Quantize AFTER the BinOp +// Expect: the Q→DQ pair AFTER the BinOp is removed by Remove_Q_Plus_DQ. +// ============================================================================ + +func.func @cleanup_qdq_after_binop_folded_into_dq(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // Activation path: Q_act -> DQ_act + %s_act = onnx.Constant dense<5.000000e-01> : tensor + %zp_act = onnx.Constant dense<0> : tensor + %q_act = "onnx.QuantizeLinear"(%arg0, %s_act, %zp_act) + : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> + %dq_act = "onnx.DequantizeLinear"(%q_act, %s_act, %zp_act) + : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> + %c_q = onnx.Constant dense<10> : tensor + %c_s = onnx.Constant dense<5.000000e+00> : tensor + %c_zp = onnx.Constant dense<0> : tensor + %dq_c = "onnx.DequantizeLinear"(%c_q, %c_s, %c_zp) + : (tensor, tensor, tensor) -> tensor + %add = "onnx.Add"(%dq_act, %dq_c) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %s_out = onnx.Constant dense<1.000000e-01> : tensor + %zp_out = onnx.Constant dense<0> : tensor + %q_out = "onnx.QuantizeLinear"(%add, %s_out, %zp_out) + : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> + + %dq_out = "onnx.DequantizeLinear"(%q_out, %s_out, %zp_out) + : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> + return %dq_out : tensor<4xf32> +} + +// CHECK-LABEL: func.func @cleanup_qdq_after_binop_folded_into_dq( +// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[S:.*]] = onnx.Constant dense<1.000000e-01> : tensor +// CHECK: %[[ZP_DQ:.*]] = onnx.Constant dense<0> : tensor +// CHECK: %[[ZP_Q:.*]] = onnx.Constant dense<-1> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S]], %[[ZP_Q]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} +// CHECK-SAME: : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> +// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S]], %[[ZP_DQ]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64} +// CHECK-SAME: : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> +// CHECK-NOT: "onnx.Add" +// CHECK: return %[[DQ]] : tensor<4xf32> + + +// ============================================================================ +// Test B: Fold happened into Q → chainStartQ = Quantize feeding DQ_act.x +// Expect: the UPSTREAM activation Q→DQ pair is removed by Remove_Q_Plus_DQ. +// ============================================================================ + +func.func @cleanup_qdq_activation_pair_folded_into_q(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // Activation path in fp, then (Q_act -> DQ_act) feeding the BinOp: + %s_act = onnx.Constant dense<2.500000e-01> : tensor + %zp_act = onnx.Constant dense<0> : tensor + %q_act = "onnx.QuantizeLinear"(%arg0, %s_act, %zp_act) + : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> + %dq_act = "onnx.DequantizeLinear"(%q_act, %s_act, %zp_act) + : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> + %c_q = onnx.Constant dense<4> : tensor + %c_s = onnx.Constant dense<1.000000e+00> : tensor + %c_zp = onnx.Constant dense<0> : tensor + %dq_c = "onnx.DequantizeLinear"(%c_q, %c_s, %c_zp) + : (tensor, tensor, tensor) -> tensor + %mul = "onnx.Mul"(%dq_act, %dq_c) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %s_out = onnx.Constant dense<1.250000e-01> : tensor + %zp_out = onnx.Constant dense<0> : tensor + %q_out2 = "onnx.QuantizeLinear"(%mul, %s_out, %zp_out) + : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> + %dq_out2 = "onnx.DequantizeLinear"(%q_out2, %s_out, %zp_out) + : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> + + return %dq_out2 : tensor<4xf32> +} + +// CHECK-LABEL: func.func @cleanup_qdq_activation_pair_folded_into_q( +// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[S_DQ:.*]] = onnx.Constant dense<1.250000e-01> : tensor +// CHECK: %[[ZP:.*]] = onnx.Constant dense<0> : tensor +// CHECK: %[[S_Q:.*]] = onnx.Constant dense<3.125000e-02> : tensor +// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S_Q]], %[[ZP]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} +// CHECK-SAME: : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> +// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S_DQ]], %[[ZP]]) +// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64} +// CHECK-SAME: : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> +// CHECK-NOT: "onnx.Add" +// CHECK-NOT: "onnx.Mul" +// CHECK-NOT: "onnx.Div" +// CHECK-NOT: "onnx.Sub +// CHECK: return %[[DQ]] : tensor<4xf32> \ No newline at end of file From 30d56b64f1726c9e7327dfb8740c367f52f5bdb1 Mon Sep 17 00:00:00 2001 From: sushmita Date: Tue, 11 Nov 2025 08:59:35 +0530 Subject: [PATCH 3/5] run command corrected --- test/mlir/onnx/onnx_remove_binary.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/mlir/onnx/onnx_remove_binary.mlir b/test/mlir/onnx/onnx_remove_binary.mlir index 740c557d7b..fac2c9c2df 100644 --- a/test/mlir/onnx/onnx_remove_binary.mlir +++ b/test/mlir/onnx/onnx_remove_binary.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt --pass-pipeline="func.func(your-pass-name)" %s | FileCheck %s +// RUN: onnx-mlir-opt --dq-binary-q-opt-onnx-to-onnx %s -split-input-file | FileCheck %s + func.func @test_fold_mul_case_b_safe(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { %0 = onnx.Constant dense<0> : tensor From e80a7318161bfcd12b9a7300666b8d30e07c525e Mon Sep 17 00:00:00 2001 From: sushmita Date: Tue, 11 Nov 2025 09:14:05 +0530 Subject: [PATCH 4/5] apply clang-format --- src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp | 79 +++++++++----------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp index 9a4d57fcf5..0c35018b7d 100644 --- a/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp +++ b/src/Dialect/ONNX/Transforms/DQBinaryQOpt.cpp @@ -4,7 +4,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -18,6 +17,7 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Pass/Passes.hpp" +#include "llvm/ADT/STLExtras.h" #include #include #include @@ -108,8 +108,8 @@ std::optional getScalarTensorValueFromVal(Value value) { return getScalarTensorValue(constOp); } -static mlir::DenseElementsAttr makeScalarDEA(mlir::ShapedType likeTy, - double d) { +static mlir::DenseElementsAttr makeScalarDEA( + mlir::ShapedType likeTy, double d) { using namespace mlir; auto ranked = likeTy.dyn_cast(); @@ -127,7 +127,7 @@ static mlir::DenseElementsAttr makeScalarDEA(mlir::ShapedType likeTy, llvm::APFloat ap(d); bool loses = false; ap.convert(useFT.getFloatSemantics(), llvm::APFloat::rmNearestTiesToEven, - &loses); + &loses); dv = ap.convertToDouble(); } return DenseElementsAttr::get(ranked, FloatAttr::get(outFT, dv)); @@ -167,8 +167,7 @@ static mlir::DenseElementsAttr makeScalarDEA(mlir::ShapedType likeTy, } static void updateInitializer(mlir::PatternRewriter &rewriter, - mlir::Operation *targetOp, mlir::Value oldInit, - double newScalar) { + mlir::Operation *targetOp, mlir::Value oldInit, double newScalar) { using namespace mlir; if (!targetOp || !oldInit) @@ -237,9 +236,9 @@ static void updateInitializer(mlir::PatternRewriter &rewriter, // Returns success() iff Q->DQ is *removable* under strict checks. // If doRewrite==true, it also *applies* the rewrite for this DQ (replaces DQ // with Q.x). -static mlir::LogicalResult -tryRemoveQThenDQChain(mlir::PatternRewriter &rewriter, - mlir::ONNXDequantizeLinearOp dqOp, bool doRewrite) { +static mlir::LogicalResult tryRemoveQThenDQChain( + mlir::PatternRewriter &rewriter, mlir::ONNXDequantizeLinearOp dqOp, + bool doRewrite) { using namespace mlir; // Match direct Q -> DQ @@ -287,8 +286,8 @@ tryRemoveQThenDQChain(mlir::PatternRewriter &rewriter, // If doRewrite=false: returns true iff *any* removable DQ user exists (no // mutation). If doRewrite=true : performs removals and returns true iff it // removed at least one DQ. Also erases Q if it becomes dead after removals. -static bool Remove_Q_Plus_DQ(mlir::PatternRewriter &rewriter, - ONNXQuantizeLinearOp qOp, bool doRewrite) { +static bool Remove_Q_Plus_DQ( + mlir::PatternRewriter &rewriter, ONNXQuantizeLinearOp qOp, bool doRewrite) { using namespace mlir; if (!qOp) return false; @@ -323,7 +322,7 @@ static bool isValuePreservingOp(mlir::Operation *op) { if (!op) return false; return isa(op); + mlir::ONNXUnsqueezeOp, mlir::ONNXTransposeOp>(op); } template @@ -354,8 +353,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { }; LogicalResult match_qdq(mlir::PatternRewriter &rewriter, MatchState &state, - ONNXDequantizeLinearOp dq1, - ONNXDequantizeLinearOp dq2) const { + ONNXDequantizeLinearOp dq1, ONNXDequantizeLinearOp dq2) const { ONNXDequantizeLinearOp constantDqOp = nullptr; ONNXConstantOp constantSourceOp = nullptr; @@ -403,8 +401,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { { auto scalar_value_opt = getScalarTensorValue(constantSourceOp); if (!scalar_value_opt) { - return rewriter.notifyMatchFailure( - constantSourceOp, + return rewriter.notifyMatchFailure(constantSourceOp, " must be a scalar value or a list of same value"); } Value scaleVal = constantDqOp.getXScale(); @@ -422,7 +419,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } LogicalResult match_binary_op(mlir::PatternRewriter &rewriter, - MatchState &state, BinOp binaryOp) const { + MatchState &state, BinOp binaryOp) const { ONNXConstantOp constantOp = nullptr; Value lhs = binaryOp.getOperand(0); @@ -468,8 +465,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } LogicalResult check_needed_values(mlir::PatternRewriter &rewriter, - const MatchState &state, - Operation *binaryOp) const { + const MatchState &state, Operation *binaryOp) const { const bool dstIsDQ = llvm::isa(state.dstNode); const bool dstIsQ = llvm::isa(state.dstNode); @@ -479,14 +475,12 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { if (state.kValue == 0.0) { if (dstIsDQ && llvm::isa(binaryOp)) { - return rewriter.notifyMatchFailure( - binaryOp, + return rewriter.notifyMatchFailure(binaryOp, "when opType is Div, remove binary op only if k_value is not zero, " "to avoid ZeroDivisionError"); } else if (dstIsQ && llvm::isa(binaryOp)) { - return rewriter.notifyMatchFailure( - binaryOp, + return rewriter.notifyMatchFailure(binaryOp, "when opType is Mul, remove binary op only if k_value is not zero, " "to avoid ZeroDivisionError"); } @@ -495,8 +489,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // k/scale is used for Add/Sub to update zero_point. // Avoid division by zero when dstScale == 0. if (state.dstScale == 0.0 && (llvm::isa(binaryOp))) { - return rewriter.notifyMatchFailure( - binaryOp, + return rewriter.notifyMatchFailure(binaryOp, "when opType is Add or Sub, remove binary op only if scale is not " "zero, to avoid ZeroDivisionError"); } @@ -535,7 +528,7 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { std::is_same_v || std::is_same_v || std::is_same_v, - "Unsupported binary operation type for this pattern"); + "Unsupported binary operation type for this pattern"); return false; } @@ -547,8 +540,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { return true; } - LogicalResult findDestinationNode(mlir::PatternRewriter &rewriter, - MatchState &state, Operation *op) const { + LogicalResult findDestinationNode( + mlir::PatternRewriter &rewriter, MatchState &state, Operation *op) const { auto dq = state.dequantActivationOfBinOp; if (!dq) return rewriter.notifyMatchFailure( @@ -609,8 +602,8 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { } public: - LogicalResult matchAndRewrite(BinOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + BinOp op, PatternRewriter &rewriter) const override { // STEP 1: Match begin: Assuming only one user if (!op->hasOneUse()) { @@ -627,10 +620,10 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // STEP 2 if (failed(match_binary_op(rewriter, state, op))) { - return rewriter.notifyMatchFailure( - op, " does not match to critieria to remove binary. Remove binary op " - "only if one of the dequantize linear input " - "has const scalar value "); + return rewriter.notifyMatchFailure(op, + " does not match to critieria to remove binary. Remove binary op " + "only if one of the dequantize linear input " + "has const scalar value "); } // STEP 3 @@ -661,13 +654,13 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // Update zero-point at DQ.x Value xZp = dqDst.getXZeroPoint(); updateInitializer(rewriter, dqDst.getOperation(), xZp, - static_cast(state.newZp)); + static_cast(state.newZp)); } else if constexpr (std::is_same_v || std::is_same_v) { // Update scale at DQ.x Value xScale = dqDst.getXScale(); - updateInitializer(rewriter, dqDst.getOperation(), xScale, - state.newScale); + updateInitializer( + rewriter, dqDst.getOperation(), xScale, state.newScale); } } else if (auto qDst = llvm::dyn_cast(dst)) { if constexpr (std::is_same_v || @@ -675,13 +668,13 @@ struct FoldBinaryThroughQDQ : public OpRewritePattern { // Update zero-point at Q.y Value yZp = qDst.getYZeroPoint(); updateInitializer(rewriter, qDst.getOperation(), yZp, - static_cast(state.newZp)); + static_cast(state.newZp)); } else if constexpr (std::is_same_v || std::is_same_v) { // Update scale at Q.y Value yScale = qDst.getYScale(); - updateInitializer(rewriter, qDst.getOperation(), yScale, - state.newScale); + updateInitializer( + rewriter, qDst.getOperation(), yScale, state.newScale); } } else { return rewriter.notifyMatchFailure( @@ -721,9 +714,7 @@ struct FoldDQBinaryQPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldDQBinaryQPass) - StringRef getArgument() const final { - return "dq-binary-q-opt-onnx-to-onnx"; - } + StringRef getArgument() const final { return "dq-binary-q-opt-onnx-to-onnx"; } StringRef getDescription() const final { return "Fold Add/Sub/Mul/Div through Q/DQ by updating scale/zero_point, " "then remove trivial Q->DQ chains when safe."; @@ -734,7 +725,7 @@ struct FoldDQBinaryQPass RewritePatternSet patterns(&getContext()); patterns .add, FoldBinaryThroughQDQ, - FoldBinaryThroughQDQ, FoldBinaryThroughQDQ>( + FoldBinaryThroughQDQ, FoldBinaryThroughQDQ>( &getContext()); if (failed(applyPatternsGreedily(function, std::move(patterns)))) signalPassFailure(); From c2e367fc04ce12aaab0b58cba3b8a457dd9df979 Mon Sep 17 00:00:00 2001 From: sushmita Date: Tue, 11 Nov 2025 11:01:10 +0530 Subject: [PATCH 5/5] removed redundant tests --- test/mlir/onnx/onnx_remove_binary.mlir | 93 ++------------------------ 1 file changed, 6 insertions(+), 87 deletions(-) diff --git a/test/mlir/onnx/onnx_remove_binary.mlir b/test/mlir/onnx/onnx_remove_binary.mlir index fac2c9c2df..bd34a61726 100644 --- a/test/mlir/onnx/onnx_remove_binary.mlir +++ b/test/mlir/onnx/onnx_remove_binary.mlir @@ -29,7 +29,7 @@ // CHECK-NOT: "onnx.Mul" // ============================================================================ -// ===== CASE A: lhs = DQ, rhs = Const (fold into Q; update Q.y_zero_point) ===== +// CASE A: lhs = DQ, rhs = Const (fold into Q; update Q.y_zero_point) // ============================================================================ func.func @caseA_lhsDQ_rhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { @@ -55,7 +55,7 @@ func.func @caseA_lhsDQ_rhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4x // CHECK: return %[[Q]] : tensor<1x4xi8> // ============================================================================ -// ===== CASE A-REV: rhs = DQ, lhs = Const (fold into Q; update Q.y_zero_point) ===== +// CASE A-REV: rhs = DQ, lhs = Const (fold into Q; update Q.y_zero_point) // ============================================================================ func.func @caseA_rev_rhsDQ_lhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { @@ -81,7 +81,7 @@ func.func @caseA_rev_rhsDQ_lhsConst_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor< // CHECK: return %[[Q]] : tensor<1x4xi8> // ============================================================================ -// ===== CASE B: both inputs are DQ; constant via dq1 (fold into Q) ===== +// CASE B: both inputs are DQ; constant via dq1 (fold into Q) // ============================================================================ func.func @caseB_bothDQ_constViaDQ1_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { @@ -107,7 +107,7 @@ func.func @caseB_bothDQ_constViaDQ1_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor< // CHECK: return %[[Q]] : tensor<1x4xi8> // ============================================================================ -// ===== CASE B with value-preserving link on constant side: Reshape(const_q) → DQ ===== +// CASE B with value-preserving link on constant side: Reshape(const_q) → DQ // ============================================================================ func.func @caseB_constViaReshape_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor<1x4xi8> { @@ -139,7 +139,8 @@ func.func @caseB_bothDQ_constViaDQ1_foldIntoQ(%arg0: tensor<1x4xf32>) -> tensor< // CHECK-NOT: onnx.Reshape // ============================================================================ -// ===== BRANCH-BEFORE: Q1 has another user (fold into DQ; update DQ.x_zero_point) ===== +// BRANCH-BEFORE: Q1 has another user (fold into DQ; update DQ.x_zero_point) +// Also checks for Removal of Q->DQ chain after the binary op // ============================================================================ func.func @branchBefore_foldIntoDQ(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi8>) { @@ -250,89 +251,7 @@ func.func @test_kval_0_dst_q_mul(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { // CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> // CHECK: return %[[DQOUT]] : tensor<10x1xf32> -// ============================================================================ -// k_value == 0 and (dst is QuantizeLinear) with a Div -// ============================================================================ - -func.func @test_kval_0_dst_q_div(%arg0: tensor<10x1xf32>) -> tensor<10x1xf32> { - %0 = onnx.Constant dense<0> : tensor - %1 = onnx.Constant dense<5.78499521E-6> : tensor - %2 = onnx.Constant dense<0> : tensor - %3 = onnx.Constant dense<0.00152590231> : tensor - %4 = onnx.Constant dense<0> : tensor - %5 = onnx.Constant dense<10> : tensor - %6 = onnx.Constant dense<1.000000e-01> : tensor - %7 = "onnx.DequantizeLinear"(%4, %3, %2) {axis = 1 : si64, block_size = 0 : si64} - : (tensor, tensor, tensor) -> tensor - %8 = "onnx.QuantizeLinear"(%arg0, %1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} - : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> - %9 = "onnx.DequantizeLinear"(%8, %1, %0) {axis = 1 : si64, block_size = 0 : si64} - : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> - %10 = "onnx.Div"(%9, %7) : (tensor<10x1xf32>, tensor) -> tensor<10x1xf32> - %11 = "onnx.QuantizeLinear"(%10, %6, %5) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} - : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> - %12 = "onnx.DequantizeLinear"(%11, %6, %5) {axis = 1 : si64, block_size = 0 : si64} - : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> - - return %12 : tensor<10x1xf32> -} -// CHECK-LABEL: func.func @test_kval_0_dst_q_div( -// CHECK-SAME: %arg0: tensor<10x1xf32>) -> tensor<10x1xf32> -// CHECK: %[[ZP:.*]] = onnx.Constant dense<10> : tensor -// CHECK: %[[S_DQ:.*]] = onnx.Constant dense<1.000000e-01> : tensor -// CHECK: %[[S_Q:.*]] = onnx.Constant dense<0.000000e+00> : tensor -// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S_Q]], %[[ZP]]) -// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} -// CHECK-SAME: : (tensor<10x1xf32>, tensor, tensor) -> tensor<10x1xui16> -// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S_DQ]], %[[ZP]]) -// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64} -// CHECK-SAME: : (tensor<10x1xui16>, tensor, tensor) -> tensor<10x1xf32> -// CHECK-NOT: "onnx.Div" -// CHECK: return %[[DQ]] : tensor<10x1xf32> - -// ============================================================================ -// Test A: Fold happened into DQ → chainStartQ = Quantize AFTER the BinOp -// Expect: the Q→DQ pair AFTER the BinOp is removed by Remove_Q_Plus_DQ. -// ============================================================================ - -func.func @cleanup_qdq_after_binop_folded_into_dq(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // Activation path: Q_act -> DQ_act - %s_act = onnx.Constant dense<5.000000e-01> : tensor - %zp_act = onnx.Constant dense<0> : tensor - %q_act = "onnx.QuantizeLinear"(%arg0, %s_act, %zp_act) - : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> - %dq_act = "onnx.DequantizeLinear"(%q_act, %s_act, %zp_act) - : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> - %c_q = onnx.Constant dense<10> : tensor - %c_s = onnx.Constant dense<5.000000e+00> : tensor - %c_zp = onnx.Constant dense<0> : tensor - %dq_c = "onnx.DequantizeLinear"(%c_q, %c_s, %c_zp) - : (tensor, tensor, tensor) -> tensor - %add = "onnx.Add"(%dq_act, %dq_c) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %s_out = onnx.Constant dense<1.000000e-01> : tensor - %zp_out = onnx.Constant dense<0> : tensor - %q_out = "onnx.QuantizeLinear"(%add, %s_out, %zp_out) - : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> - - %dq_out = "onnx.DequantizeLinear"(%q_out, %s_out, %zp_out) - : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> - return %dq_out : tensor<4xf32> -} - -// CHECK-LABEL: func.func @cleanup_qdq_after_binop_folded_into_dq( -// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<4xf32> -// CHECK: %[[S:.*]] = onnx.Constant dense<1.000000e-01> : tensor -// CHECK: %[[ZP_DQ:.*]] = onnx.Constant dense<0> : tensor -// CHECK: %[[ZP_Q:.*]] = onnx.Constant dense<-1> : tensor -// CHECK: %[[Q:.*]] = "onnx.QuantizeLinear"(%arg0, %[[S]], %[[ZP_Q]]) -// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} -// CHECK-SAME: : (tensor<4xf32>, tensor, tensor) -> tensor<4xi8> -// CHECK: %[[DQ:.*]] = "onnx.DequantizeLinear"(%[[Q]], %[[S]], %[[ZP_DQ]]) -// CHECK-SAME: {axis = 1 : si64, block_size = 0 : si64} -// CHECK-SAME: : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> -// CHECK-NOT: "onnx.Add" -// CHECK: return %[[DQ]] : tensor<4xf32> // ============================================================================