From df2d910b9c924ea07050b51ce453d46dcc14b478 Mon Sep 17 00:00:00 2001 From: Rachit Gupta Date: Thu, 18 Sep 2025 11:26:13 -0500 Subject: [PATCH 1/8] Add qdq nodes around non qdq ops --- src/Compiler/OnnxToMlirPasses.cpp | 3 + src/Compiler/OnnxToMlirPasses.hpp | 1 + src/Dialect/ONNX/Transforms/AddQDQOpt.cpp | 185 +++++++++++++++++++++ src/Dialect/ONNX/Transforms/CMakeLists.txt | 3 +- src/Pass/Passes.hpp | 1 + src/Tools/onnx-mlir-opt/RegisterPasses.cpp | 4 + test/mlir/onnx/onnx_add_qdq.mlir | 41 +++++ 7 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 src/Dialect/ONNX/Transforms/AddQDQOpt.cpp create mode 100644 test/mlir/onnx/onnx_add_qdq.mlir diff --git a/src/Compiler/OnnxToMlirPasses.cpp b/src/Compiler/OnnxToMlirPasses.cpp index 291ca8e464..5a58c38f57 100644 --- a/src/Compiler/OnnxToMlirPasses.cpp +++ b/src/Compiler/OnnxToMlirPasses.cpp @@ -97,6 +97,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, if (opts.enableRemoveDqQOp) pm.addPass(createQDQOptONNXToONNXPass()); + if (opts.enableAddQDQOp) + pm.addPass(createMissingQDQAroundOpOptONNXToONNXPass()); + // One more call to ONNX shape inference/canonicalization/... to update // shape if possible. if (opts.enableONNXHybridPass) { diff --git a/src/Compiler/OnnxToMlirPasses.hpp b/src/Compiler/OnnxToMlirPasses.hpp index a7532ec926..6c2d0a7b4a 100644 --- a/src/Compiler/OnnxToMlirPasses.hpp +++ b/src/Compiler/OnnxToMlirPasses.hpp @@ -19,6 +19,7 @@ struct OnnxToMlirOptions { bool enableRemoveDqQOp = true; bool enableRemoveDqQAroundOp = true; bool enableRemoveBinary = false; + bool enableAddQDQOp = false; bool disableRecomposeOption = false; bool enableONNXHybridPass = true; diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp new file mode 100644 index 0000000000..e93008713c --- /dev/null +++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp @@ -0,0 +1,185 @@ +//===- foldDqBinaryQPattern.cpp - Remove DQ-Binary-Q chains -----*- C++ -*-===// +// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Pass/Passes.hpp"
+#include "llvm/ADT/STLExtras.h"
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace {
+
+struct AddQDQAroundOp : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddQDQAroundOp)
+
+  StringRef getArgument() const final { return "add-qdq-around-op"; }
+  StringRef getDescription() const final {
+    return "Add Q, DQ around nodes which are missing them";
+  }
+
+  Type extractZeroPointType(OpBuilder &builder,func::FuncOp &func){
+    Type zpElemType = builder.getIntegerType(8); // default int8
+
+    func.walk([&](Operation *op) -> WalkResult {
+        if (auto q = mlir::dyn_cast<ONNXQuantizeLinearOp>(op)) {
+          Value zp = q.getYZeroPoint();
+          if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+            zpElemType = zpShaped.getElementType();
+            return WalkResult::interrupt();
+          }
+        } else if (auto dq = mlir::dyn_cast<ONNXDequantizeLinearOp>(op)) {
+          Value zp = dq.getXZeroPoint();
+          if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+            zpElemType = zpShaped.getElementType();
+            return WalkResult::interrupt();
+          }
+        }
+        return WalkResult::advance();
+    });
+    return zpElemType;
+
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *ctx = &getContext();
+    OpBuilder builder(ctx);
+
+    Block &entryBlock = func.getBody().front();
+    Location entryLoc = func.getLoc();
+
+
+    RankedTensorType scaleType = RankedTensorType::get({}, builder.getF32Type());
+    DenseElementsAttr scaleAttr = DenseElementsAttr::get(scaleType, {1.0f});
+    OperationState scaleState(entryLoc, "onnx.Constant");
+    scaleState.addAttribute("value", scaleAttr);
+    scaleState.addTypes(scaleAttr.getType());
+    Operation *scaleOp = Operation::create(scaleState);
+    entryBlock.getOperations().insert(entryBlock.begin(), scaleOp);
+    Value scaleVal = scaleOp->getResult(0);
+
+    Type zpEleType = extractZeroPointType(builder, func);
+    RankedTensorType zpType = RankedTensorType::get({}, zpEleType);
+    DenseElementsAttr zpAttr;
+    if (zpEleType.isa<IntegerType>()){
+      auto intType = mlir::dyn_cast<IntegerType>(zpEleType);
+      unsigned width = intType.getWidth();
+      bool isSigned = intType.isSignedInteger();
+
+      if (width == 8) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<int8_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<uint8_t>(0)});
+        }
+      } else if (width == 16) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<int16_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<uint16_t>(0)});
+        }
+      } else if (width == 32) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<int32_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(
+              zpType,
+              {static_cast<uint32_t>(0)});
+        }
+      } else {
+        // fallback: default int8 zero-point
+        zpAttr = DenseElementsAttr::get(
+            zpType,
+            {static_cast<int8_t>(0)});
+      }
+    } else {
+      // fallback if not integer
+      zpAttr =
+          DenseElementsAttr::get(
+              zpType,
+              {static_cast<int8_t>(0)});
+    }
+
+
+    OperationState zpState(entryLoc, "onnx.Constant");
+    zpState.addAttribute("value", zpAttr);
+    zpState.addTypes(zpAttr.getType());
+    Operation *zpOp = Operation::create(zpState);
+    entryBlock.getOperations().insert(entryBlock.begin(), zpOp);
+    Value zpVal = zpOp->getResult(0);
+
+    llvm::SmallDenseMap<Value, Value> producerToDQ;
+
+    for (Operation &opRef : llvm::make_early_inc_range(func.getOps())) {
+        Operation *op = &opRef;
+
+        if (isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(op))
+          continue;
+
+        Location loc = op->getLoc();
+
+        for (Value operand : op->getOperands()) {
+          Operation *def = operand.getDefiningOp();
+          if (def && isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(def)) continue;
+
+          auto it = producerToDQ.find(operand);
+          if (it != producerToDQ.end()) {
+            op->replaceUsesOfWith(operand, it->second);
+            continue;
+          }
+
+          builder.setInsertionPoint(op);
+
+          ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());
+
+          Type qResultType = operand.getType();
+          if (inShaped)
+            qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
+
+          auto q = builder.create<ONNXQuantizeLinearOp>(loc, qResultType, operand, scaleVal, zpVal);
+          auto dq = builder.create<ONNXDequantizeLinearOp>(loc, operand.getType(), q.getResult(), scaleVal, zpVal);
+
+          producerToDQ.try_emplace(operand, dq.getResult());
+          op->replaceUsesOfWith(operand, dq.getResult());
+        }
+    }
+  }
+};
+} // namespace
+
+namespace onnx_mlir {
+std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass() {
+  return std::make_unique<AddQDQAroundOp>();
+}
+} // namespace onnx_mlir
\ No newline at end of file
diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt
index 7862443935..3b4e72d6a1 100644
--- a/src/Dialect/ONNX/Transforms/CMakeLists.txt
+++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -44,7 +44,8 @@ add_onnx_mlir_library(OMONNXRewrite
   ConstProp.cpp
   QDQAroundOpOpt.cpp
   QDQOpt.cpp
-  DQBinaryQOpt.cpp
+  DQBinaryQOpt.cpp 
+  AddQDQopt.cpp
   ConvOpt.cpp
   Decompose.cpp
   DecomposeEinsum.cpp
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 8e4855e6dc..f0420d4b63 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -58,6 +58,7 @@ std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();
 std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
 
 std::unique_ptr<mlir::Pass> createFoldDQBinaryQPass();
+std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass();
 
 /// Pass for instrument the ops in specific stage.
 std::unique_ptr<mlir::Pass> createInstrumentPass();
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index ba6635b466..cc0dadc63c 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -82,6 +82,10 @@ void registerOMPasses(int optLevel) {
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createFoldDQBinaryQPass();
   });
+  
+  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
+    return createMissingQDQAroundOpOptONNXToONNXPass();
+  });  
 
   mlir::registerPass(
       []() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
diff --git a/test/mlir/onnx/onnx_add_qdq.mlir b/test/mlir/onnx/onnx_add_qdq.mlir
new file mode 100644
index 0000000000..213a0b16c2
--- /dev/null
+++ b/test/mlir/onnx/onnx_add_qdq.mlir
@@ -0,0 +1,41 @@
+// RUN: onnx-mlir-opt --add-qdq-around-op %s | FileCheck %s
+
+func.func @test_inserted_qdq(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = onnx.Constant {value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32>
+  %init = onnx.Constant dense<2.0> : tensor<f32>
+  %init_div = onnx.Constant dense<4.0> : tensor<f32>
+  %cst_qdq_zp = onnx.Constant dense<0> : tensor<i16>
+  %cst_qdq_s = onnx.Constant dense<1.52590219E-5> : tensor<f32>
+  %0 = "onnx.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "onnx.Mul"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+  %2 = "onnx.Sub"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+  %3 = "onnx.QuantizeLinear"(%0, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x2xf32>, tensor<f32>, tensor<i16>) -> tensor<2x2xi16>
+  %4 = "onnx.DequantizeLinear"(%3, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x2xi16>, tensor<f32>, tensor<i16>) -> tensor<2x2xf32>
+  %5 = "onnx.Div"(%4, %init_div) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+  %6 = "onnx.Add"(%1, %2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  return %6 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_inserted_qdq
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Add
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Mul
+// CHECK: onnx.Sub
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Div
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Add
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
\ No newline at end of file

From f8813100d4ecc8abf050e10df31502580cbc8560 Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Thu, 18 Sep 2025 11:38:49 -0500
Subject: [PATCH 2/8] clean up

---
 src/Dialect/ONNX/Transforms/AddQDQOpt.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
index e93008713c..9798c7698f 100644
--- a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
+++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
@@ -1,4 +1,4 @@
-//===- foldDqBinaryQPattern.cpp - Remove DQ-Binary-Q chains -----*- C++ -*-===//
+
 // (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
 //
 //===----------------------------------------------------------------------===//

From bf2a007f4fe3b22451450644abd58e85186b95ca Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Fri, 19 Sep 2025 03:30:37 -0500
Subject: [PATCH 3/8] Check for only float input and provide qdq around scalar

---
 src/Dialect/ONNX/Transforms/AddQDQOpt.cpp | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
index 9798c7698f..6cf2202467 100644
--- a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
+++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
@@ -83,7 +83,7 @@ struct AddQDQAroundOp : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
     Type zpEleType = extractZeroPointType(builder, func);
     RankedTensorType zpType = RankedTensorType::get({}, zpEleType);
     DenseElementsAttr zpAttr;
-    if (zpEleType.isa<IntegerType>()){
+    if (isa<IntegerType>(zpEleType)){
       auto intType = mlir::dyn_cast<IntegerType>(zpEleType);
       unsigned width = intType.getWidth();
       bool isSigned = intType.isSignedInteger();
@@ -164,14 +164,20 @@ struct AddQDQAroundOp : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
           ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());
 
           Type qResultType = operand.getType();
-          if (inShaped)
-            qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
+          if (isa<FloatType>(operand.getType()) || (inShaped && isa<FloatType>(inShaped.getElementType()))){
+            if (inShaped)
+              qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
+            else
+              qResultType = RankedTensorType::get({}, zpEleType);
+
+            auto q = builder.create<ONNXQuantizeLinearOp>(loc, qResultType, operand, scaleVal, zpVal);
+            auto dq = builder.create<ONNXDequantizeLinearOp>(loc, operand.getType(), q.getResult(), scaleVal, zpVal);
+
+            producerToDQ.try_emplace(operand, dq.getResult());
+            op->replaceUsesOfWith(operand, dq.getResult());
 
-          auto q = builder.create<ONNXQuantizeLinearOp>(loc, qResultType, operand, scaleVal, zpVal);
-          auto dq = builder.create<ONNXDequantizeLinearOp>(loc, operand.getType(), q.getResult(), scaleVal, zpVal);
+          }
 
-          producerToDQ.try_emplace(operand, dq.getResult());
-          op->replaceUsesOfWith(operand, dq.getResult());
         }
     }
   }

From 54d426d646021af32c74587ea49814852996420f Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Fri, 19 Sep 2025 04:17:44 -0500
Subject: [PATCH 4/8] lint fix

---
 src/Compiler/OnnxToMlirPasses.cpp          |   2 +-
 src/Dialect/ONNX/Transforms/AddQDQOpt.cpp  | 183 ++++++++++-----------
 src/Tools/onnx-mlir-opt/RegisterPasses.cpp |   4 +-
 3 files changed, 87 insertions(+), 102 deletions(-)

diff --git a/src/Compiler/OnnxToMlirPasses.cpp b/src/Compiler/OnnxToMlirPasses.cpp
index 5a58c38f57..ccba81eb6c 100644
--- a/src/Compiler/OnnxToMlirPasses.cpp
+++ b/src/Compiler/OnnxToMlirPasses.cpp
@@ -98,7 +98,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
     pm.addPass(createQDQOptONNXToONNXPass());
 
   if (opts.enableAddQDQOp)
-    pm.addPass(createMissingQDQAroundOpOptONNXToONNXPass()); 
+    pm.addPass(createMissingQDQAroundOpOptONNXToONNXPass());
 
   // One more call to ONNX shape inference/canonicalization/... to update
   // shape if possible.
diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
index 6cf2202467..4d59d09810 100644
--- a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
+++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
@@ -30,7 +30,8 @@ using namespace onnx_mlir;
 
 namespace {
 
-struct AddQDQAroundOp : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
+struct AddQDQAroundOp
+    : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddQDQAroundOp)
 
@@ -38,30 +39,29 @@ struct AddQDQAroundOp : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
   StringRef getArgument() const final { return "add-qdq-around-op"; }
   StringRef getDescription() const final {
     return "Add Q, DQ around nodes which are missing them";
   }
 
-  Type extractZeroPointType(OpBuilder &builder,func::FuncOp &func){
+  Type extractZeroPointType(OpBuilder &builder, func::FuncOp &func) {
     Type zpElemType = builder.getIntegerType(8); // default int8
 
     func.walk([&](Operation *op) -> WalkResult {
-        if (auto q = mlir::dyn_cast<ONNXQuantizeLinearOp>(op)) {
-          Value zp = q.getYZeroPoint();
-          if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
-            zpElemType = zpShaped.getElementType();
-            return WalkResult::interrupt();
-          }
-        } else if (auto dq = mlir::dyn_cast<ONNXDequantizeLinearOp>(op)) {
-          Value zp = dq.getXZeroPoint();
-          if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
-            zpElemType = zpShaped.getElementType();
-            return WalkResult::interrupt();
-          }
+      if (auto q = mlir::dyn_cast<ONNXQuantizeLinearOp>(op)) {
+        Value zp = q.getYZeroPoint();
+        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+          zpElemType = zpShaped.getElementType();
+          return WalkResult::interrupt();
+        }
+      } else if (auto dq = mlir::dyn_cast<ONNXDequantizeLinearOp>(op)) {
+        Value zp = dq.getXZeroPoint();
+        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+          zpElemType = zpShaped.getElementType();
+          return WalkResult::interrupt();
         }
-        return WalkResult::advance();
+      }
+      return WalkResult::advance();
     });
     return zpElemType;
-
   }
-
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     MLIRContext *ctx = &getContext();
@@ -70,8 +70,8 @@ struct AddQDQAroundOp
     Block &entryBlock = func.getBody().front();
     Location entryLoc = func.getLoc();
 
-
-    RankedTensorType scaleType = RankedTensorType::get({}, builder.getF32Type());
+    RankedTensorType scaleType =
+        RankedTensorType::get({}, builder.getF32Type());
     DenseElementsAttr scaleAttr = DenseElementsAttr::get(scaleType, {1.0f});
     OperationState scaleState(entryLoc, "onnx.Constant");
@@ -85,57 +85,40 @@ struct AddQDQAroundOp
     Type zpEleType = extractZeroPointType(builder, func);
     RankedTensorType zpType = RankedTensorType::get({}, zpEleType);
     DenseElementsAttr zpAttr;
-    if (isa<IntegerType>(zpEleType)){
-      auto intType = mlir::dyn_cast<IntegerType>(zpEleType);
-      unsigned width = intType.getWidth();
-      bool isSigned = intType.isSignedInteger();
-
-      if (width == 8) {
-        if (isSigned) {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<int8_t>(0)});
-        } else {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<uint8_t>(0)});
-        }
-      } else if (width == 16) {
-        if (isSigned) {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<int16_t>(0)});
-        } else {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<uint16_t>(0)});
-        }
-      } else if (width == 32) {
-        if (isSigned) {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<int32_t>(0)});
-        } else {
-          zpAttr = DenseElementsAttr::get(
-              zpType,
-              {static_cast<uint32_t>(0)});
-        }
-      } else {
-        // fallback: default int8 zero-point
-        zpAttr = DenseElementsAttr::get(
-            zpType,
-            {static_cast<int8_t>(0)});
-      }
-    } else {
-      // fallback if not integer
-      zpAttr =
-          DenseElementsAttr::get(
-              zpType,
-              {static_cast<int8_t>(0)});
+    if (isa<IntegerType>(zpEleType)) {
+      auto intType = mlir::dyn_cast<IntegerType>(zpEleType);
+      unsigned width = intType.getWidth();
+      bool isSigned = intType.isSignedInteger();
+
+      if (width == 8) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint8_t>(0)});
+        }
+      } else if (width == 16) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int16_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint16_t>(0)});
+        }
+      } else if (width == 32) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int32_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint32_t>(0)});
+        }
+      } else {
+        // fallback: default int8 zero-point
+        zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
+      }
+    } else {
+      // fallback if not integer
+      zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
     }
-
-
+
     OperationState zpState(entryLoc, "onnx.Constant");
     zpState.addAttribute("value", zpAttr);
     zpState.addTypes(zpAttr.getType());
@@ -142,45 +125,47 @@ struct AddQDQAroundOp
     llvm::SmallDenseMap<Value, Value> producerToDQ;
 
     for (Operation &opRef : llvm::make_early_inc_range(func.getOps())) {
-        Operation *op = &opRef;
+      Operation *op = &opRef;
 
-        if (isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(op))
+      if (isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(op))
         continue;
 
-        Location loc = op->getLoc();
-
-        for (Value operand : op->getOperands()) {
-          Operation *def = operand.getDefiningOp();
-          if (def && isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(def)) continue;
-
-          auto it = producerToDQ.find(operand);
-          if (it != producerToDQ.end()) {
-            op->replaceUsesOfWith(operand, it->second);
-            continue;
-          }
-
-          builder.setInsertionPoint(op);
-
-          ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());
-
-          Type qResultType = operand.getType();
-          if (isa<FloatType>(operand.getType()) || (inShaped && isa<FloatType>(inShaped.getElementType()))){
-            if (inShaped)
-              qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
-            else
-              qResultType = RankedTensorType::get({}, zpEleType);
-
-            auto q = builder.create<ONNXQuantizeLinearOp>(loc, qResultType, operand, scaleVal, zpVal);
-            auto dq = builder.create<ONNXDequantizeLinearOp>(loc, operand.getType(), q.getResult(), scaleVal, zpVal);
-
-            producerToDQ.try_emplace(operand, dq.getResult());
-            op->replaceUsesOfWith(operand, dq.getResult());
-
-          }
-
+      Location loc = op->getLoc();
+
+      for (Value operand : op->getOperands()) {
+        Operation *def = operand.getDefiningOp();
+        if (def && isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(def))
+          continue;
+
+        auto it = producerToDQ.find(operand);
+        if (it != producerToDQ.end()) {
+          op->replaceUsesOfWith(operand, it->second);
+          continue;
         }
+
+        builder.setInsertionPoint(op);
+
+        ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());
+
+        Type qResultType = operand.getType();
+        if (isa<FloatType>(operand.getType()) ||
+            (inShaped && isa<FloatType>(inShaped.getElementType()))) {
+          if (inShaped)
+            qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
+          else
+            qResultType = RankedTensorType::get({}, zpEleType);
+
+          auto q = builder.create<ONNXQuantizeLinearOp>(
+              loc, qResultType, operand, scaleVal, zpVal);
+          auto dq = builder.create<ONNXDequantizeLinearOp>(
+              loc, operand.getType(), q.getResult(), scaleVal, zpVal);
+
+          producerToDQ.try_emplace(operand, dq.getResult());
+          op->replaceUsesOfWith(operand, dq.getResult());
+        }
+      }
     }
+  }
 };
 } // namespace
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index cc0dadc63c..50fd89cb80 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -82,10 +82,10 @@ void registerOMPasses(int optLevel) {
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createFoldDQBinaryQPass();
   });
-  
+
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createMissingQDQAroundOpOptONNXToONNXPass();
-  });  
+  });
 
   mlir::registerPass(
       []() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });

From 8bebed4bb0ee52b0a6598a83130c575ea8c392e3 Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Fri, 19 Sep 2025 04:43:02 -0500
Subject: [PATCH 5/8] filename mismatch fix

---
 src/Dialect/ONNX/Transforms/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt
index 3b4e72d6a1..2ba5887acc 100644
--- a/src/Dialect/ONNX/Transforms/CMakeLists.txt
+++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -45,7 +45,7 @@ add_onnx_mlir_library(OMONNXRewrite
   QDQAroundOpOpt.cpp
   QDQOpt.cpp
   DQBinaryQOpt.cpp
-  AddQDQopt.cpp
+  AddQDQOpt.cpp
   ConvOpt.cpp
   Decompose.cpp
   DecomposeEinsum.cpp

From daf1dc4b2003fb63c0ba9a7b149052ded068ee1a Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Mon, 6 Oct 2025 07:22:54 -0500
Subject: [PATCH 6/8] Restrict qdq addition to dma ops only

---
 src/Dialect/ONNX/Transforms/AddQDQOpt.cpp | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
index 4d59d09810..25685e7657 100644
--- a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
+++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
@@ -62,6 +62,15 @@ struct AddQDQAroundOp
     return zpElemType;
   }
 
+  bool isDMAOp(Operation *op) {
+    return isa<ONNXReshapeOp, ONNXTransposeOp, ONNXSqueezeOp,
+        ONNXUnsqueezeOp, ONNXFlattenOp, ONNXSliceOp,
+        ONNXSplitOp, ONNXConcatOp, ONNXGatherOp,
+        ONNXGatherElementsOp, ONNXGatherNDOp, ONNXExpandOp,
+        ONNXTileOp, ONNXPadOp, ONNXDepthToSpaceOp,
+        ONNXSpaceToDepthOp, ONNXIdentityOp>(op);
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     MLIRContext *ctx = &getContext();
@@ -127,7 +136,9 @@ struct AddQDQAroundOp
     for (Operation &opRef : llvm::make_early_inc_range(func.getOps())) {
       Operation *op = &opRef;
 
-      if (isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(op))
+      if ((isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(
+              op)) ||
+          !(isDMAOp(op)))
         continue;
 
       Location loc = op->getLoc();

From 07eb2ac3a1aebbe5bb2843d997f518820a6cb8fc Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Tue, 7 Oct 2025 04:25:41 -0500
Subject: [PATCH 7/8] updated test

---
 test/mlir/onnx/onnx_add_qdq.mlir | 68 +++++++++++++++++++-------------
 1 file changed, 41 insertions(+), 27 deletions(-)

diff --git a/test/mlir/onnx/onnx_add_qdq.mlir b/test/mlir/onnx/onnx_add_qdq.mlir
index 213a0b16c2..6b2a051dbd 100644
--- a/test/mlir/onnx/onnx_add_qdq.mlir
+++ b/test/mlir/onnx/onnx_add_qdq.mlir
@@ -1,41 +1,55 @@
 // RUN: onnx-mlir-opt --add-qdq-around-op %s | FileCheck %s
-
-func.func @test_inserted_qdq(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  %cst = onnx.Constant {value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32>
-  %init = onnx.Constant dense<2.0> : tensor<f32>
-  %init_div = onnx.Constant dense<4.0> : tensor<f32>
-  %cst_qdq_zp = onnx.Constant dense<0> : tensor<i16>
-  %cst_qdq_s = onnx.Constant dense<1.52590219E-5> : tensor<f32>
-  %0 = "onnx.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "onnx.Mul"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
-  %2 = "onnx.Sub"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
-  %3 = "onnx.QuantizeLinear"(%0, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x2xf32>, tensor<f32>, tensor<i16>) -> tensor<2x2xi16>
-  %4 = "onnx.DequantizeLinear"(%3, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x2xi16>, tensor<f32>, tensor<i16>) -> tensor<2x2xf32>
-  %5 = "onnx.Div"(%4, %init_div) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
-  %6 = "onnx.Add"(%1, %2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-
-  return %6 : tensor<2x2xf32>
+func.func @test_inserted_qdq(% arg0 : tensor<2x2xf32>)->tensor<2x2xf32> {
+  % shape =
+      onnx.Constant{value = dense<[ 2, 2 ]> : tensor<2xi64>} : tensor<2xi64> %
+      cst = onnx.
+      Constant{value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32> %
+      init =
+      onnx.Constant dense<2.0> : tensor<f32> % cst_qdq_zp =
+      onnx.Constant dense<0> : tensor<i16> % cst_qdq_s =
+      onnx.Constant dense<1.52590219E-5> : tensor<f32> % 0 =
+      "onnx.Add"(% arg0, % cst)
+      : (tensor<2x2xf32>, tensor<2x2xf32>)->tensor<2x2xf32> %
+      1 = "onnx.Mul"(% 0, % init)
+      : (tensor<2x2xf32>, tensor<f32>)->tensor<2x2xf32> %
+      2 = "onnx.Reshape"(% 1, % shape)
+      : (tensor<2x2xf32>, tensor<2xi64>)->tensor<2x2xf32> %
+      3 = "onnx.QuantizeLinear"(
+      % 0, % cst_qdq_s, % cst_qdq_zp){
+      axis = 1 : si64,
+      block_size = 0 : si64,
+      output_dtype = 0 : si64,
+      saturate = 1 : si64
+  }
+  : (tensor<2x2xf32>, tensor<f32>, tensor<i16>)
+      ->tensor<2x2xi16> % 4 = "onnx.DequantizeLinear"(
+      % 3, % cst_qdq_s, % cst_qdq_zp){
+      axis = 1 : si64,
+      block_size = 0 : si64
+  }
+  : (tensor<2x2xi16>, tensor<f32>, tensor<i16>)
+      ->tensor<2x2xf32> % 5 = "onnx.Reshape"(% 4, % shape)
+      : (tensor<2x2xf32>, tensor<2xi64>)->tensor<2x2xf32> %
+      6 = "onnx.Transpose"(% 0){perm = [ 0, 1 ]}
+      : (tensor<2x2xf32>)
+      ->tensor<2x2xf32> % 7 = "onnx.Add"(% 5, % 6)
+      : (tensor<2x2xf32>, tensor<2x2xf32>)
+      ->tensor<2x2xf32>
+
+  return % 7 : tensor<2x2xf32>
 }
 
 // CHECK-LABEL: func.func @test_inserted_qdq
-// CHECK: onnx.QuantizeLinear
-// CHECK: onnx.DequantizeLinear
-// CHECK: onnx.QuantizeLinear
-// CHECK: onnx.DequantizeLinear
 // CHECK: onnx.Add
-// CHECK: onnx.QuantizeLinear
-// CHECK: onnx.DequantizeLinear
-// CHECK: onnx.QuantizeLinear
-// CHECK: onnx.DequantizeLinear
 // CHECK: onnx.Mul
-// CHECK: onnx.Sub
 // CHECK: onnx.QuantizeLinear
 // CHECK: onnx.DequantizeLinear
-// CHECK: onnx.Div
+// CHECK: onnx.Reshape
 // CHECK: onnx.QuantizeLinear
 // CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Reshape
 // CHECK: onnx.QuantizeLinear
 // CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Transpose
 // CHECK: onnx.Add
-// CHECK: onnx.QuantizeLinear
-// CHECK: onnx.DequantizeLinear
\ No newline at end of file

From 6e536c310fbe55fc84ed1c12bb5a831d92b2a4db Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Tue, 7 Oct 2025 04:38:12 -0500
Subject: [PATCH 8/8] updated test formatting

---
 test/mlir/onnx/onnx_add_qdq.mlir | 54 ++++++++++----------------------
 1 file changed, 16 insertions(+), 38 deletions(-)

diff --git a/test/mlir/onnx/onnx_add_qdq.mlir b/test/mlir/onnx/onnx_add_qdq.mlir
index 6b2a051dbd..8a0c63ee19 100644
--- a/test/mlir/onnx/onnx_add_qdq.mlir
+++ b/test/mlir/onnx/onnx_add_qdq.mlir
@@ -1,43 +1,21 @@
 // RUN: onnx-mlir-opt --add-qdq-around-op %s | FileCheck %s
+
+func.func @test_inserted_qdq(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %shape = onnx.Constant {value = dense<[2, 2]> : tensor<2xi64>} : tensor<2xi64>
+  %cst = onnx.Constant {value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32>
+  %init = onnx.Constant dense<2.0> : tensor<f32>
+  %cst_qdq_zp = onnx.Constant dense<0> : tensor<i16>
+  %cst_qdq_s = onnx.Constant dense<1.52590219E-5> : tensor<f32>
+  %0 = "onnx.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "onnx.Mul"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+  %2 = "onnx.Reshape"(%1, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
+  %3 = "onnx.QuantizeLinear"(%0, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x2xf32>, tensor<f32>, tensor<i16>) -> tensor<2x2xi16>
+  %4 = "onnx.DequantizeLinear"(%3, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x2xi16>, tensor<f32>, tensor<i16>) -> tensor<2x2xf32>
+  %5 = "onnx.Reshape"(%4, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
+  %6 = "onnx.Transpose"(%0) {perm = [0, 1]} : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  %7 = "onnx.Add"(%5, %6) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 
-func.func @test_inserted_qdq(% arg0 : tensor<2x2xf32>)->tensor<2x2xf32> {
-  % shape =
-      onnx.Constant{value = dense<[ 2, 2 ]> : tensor<2xi64>} : tensor<2xi64> %
-      cst = onnx.
-      Constant{value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32> %
-      init =
-      onnx.Constant dense<2.0> : tensor<f32> % cst_qdq_zp =
-      onnx.Constant dense<0> : tensor<i16> % cst_qdq_s =
-      onnx.Constant dense<1.52590219E-5> : tensor<f32> % 0 =
-      "onnx.Add"(% arg0, % cst)
-      : (tensor<2x2xf32>, tensor<2x2xf32>)->tensor<2x2xf32> %
-      1 = "onnx.Mul"(% 0, % init)
-      : (tensor<2x2xf32>, tensor<f32>)->tensor<2x2xf32> %
-      2 = "onnx.Reshape"(% 1, % shape)
-      : (tensor<2x2xf32>, tensor<2xi64>)->tensor<2x2xf32> %
-      3 = "onnx.QuantizeLinear"(
-      % 0, % cst_qdq_s, % cst_qdq_zp){
-      axis = 1 : si64,
-      block_size = 0 : si64,
-      output_dtype = 0 : si64,
-      saturate = 1 : si64
-  }
-  : (tensor<2x2xf32>, tensor<f32>, tensor<i16>)
-      ->tensor<2x2xi16> % 4 = "onnx.DequantizeLinear"(
-      % 3, % cst_qdq_s, % cst_qdq_zp){
-      axis = 1 : si64,
-      block_size = 0 : si64
-  }
-  : (tensor<2x2xi16>, tensor<f32>, tensor<i16>)
-      ->tensor<2x2xf32> % 5 = "onnx.Reshape"(% 4, % shape)
-      : (tensor<2x2xf32>, tensor<2xi64>)->tensor<2x2xf32> %
-      6 = "onnx.Transpose"(% 0){perm = [ 0, 1 ]}
-      : (tensor<2x2xf32>)
-      ->tensor<2x2xf32> % 7 = "onnx.Add"(% 5, % 6)
-      : (tensor<2x2xf32>, tensor<2x2xf32>)
-      ->tensor<2x2xf32>
-
-  return % 7 : tensor<2x2xf32>
+  return %7 : tensor<2x2xf32>
 }
 
 // CHECK-LABEL: func.func @test_inserted_qdq
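
--
Illustration (not part of the patches): a minimal before/after sketch of what
--add-qdq-around-op produces, assuming the defaults the pass falls back to when
the function contains no existing Q/DQ pair to borrow a zero-point type from
(scale = 1.0, int8 zero point). The function and value names are illustrative,
and the default attributes on the created Q/DQ ops are omitted. Input:

  func.func @f(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = "onnx.Transpose"(%arg0) {perm = [1, 0]} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }

After the pass, the float operand of the Transpose (a data-movement op) is
routed through a freshly created Q/DQ pair, with the scale and zero-point
constants materialized at the top of the entry block; non-DMA ops and operands
already produced by a DequantizeLinear are left untouched:

  func.func @f(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %zp = onnx.Constant dense<0> : tensor<i8>
    %scale = onnx.Constant dense<1.000000e+00> : tensor<f32>
    %q = "onnx.QuantizeLinear"(%arg0, %scale, %zp) : (tensor<2x2xf32>, tensor<f32>, tensor<i8>) -> tensor<2x2xi8>
    %dq = "onnx.DequantizeLinear"(%q, %scale, %zp) : (tensor<2x2xi8>, tensor<f32>, tensor<i8>) -> tensor<2x2xf32>
    %0 = "onnx.Transpose"(%dq) {perm = [1, 0]} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }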