diff --git a/src/Compiler/OnnxToMlirPasses.cpp b/src/Compiler/OnnxToMlirPasses.cpp
index 291ca8e464..ccba81eb6c 100644
--- a/src/Compiler/OnnxToMlirPasses.cpp
+++ b/src/Compiler/OnnxToMlirPasses.cpp
@@ -97,6 +97,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
   if (opts.enableRemoveDqQOp)
     pm.addPass(createQDQOptONNXToONNXPass());
 
+  if (opts.enableAddQDQOp)
+    pm.addPass(createMissingQDQAroundOpOptONNXToONNXPass());
+
   // One more call to ONNX shape inference/canonicalization/... to update
   // shape if possible.
   if (opts.enableONNXHybridPass) {
diff --git a/src/Compiler/OnnxToMlirPasses.hpp b/src/Compiler/OnnxToMlirPasses.hpp
index a7532ec926..6c2d0a7b4a 100644
--- a/src/Compiler/OnnxToMlirPasses.hpp
+++ b/src/Compiler/OnnxToMlirPasses.hpp
@@ -19,6 +19,7 @@ struct OnnxToMlirOptions {
   bool enableRemoveDqQOp = true;
   bool enableRemoveDqQAroundOp = true;
   bool enableRemoveBinary = false;
+  bool enableAddQDQOp = false;
 
   bool disableRecomposeOption = false;
   bool enableONNXHybridPass = true;
diff --git a/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
new file mode 100644
index 0000000000..25685e7657
--- /dev/null
+++ b/src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
@@ -0,0 +1,193 @@
+
+// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Pass/Passes.hpp"
+#include "llvm/ADT/STLExtras.h"
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace {
+
+struct AddQDQAroundOp
+    : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddQDQAroundOp)
+
+  StringRef getArgument() const final { return "add-qdq-around-op"; }
+  StringRef getDescription() const final {
+    return "Add Q/DQ around nodes which are missing them";
+  }
+
+  // Find an existing Q or DQ op and reuse its zero-point element type;
+  // default to int8 when the function has none.
+  Type extractZeroPointType(OpBuilder &builder, func::FuncOp &func) {
+    Type zpElemType = builder.getIntegerType(8); // default int8
+
+    func.walk([&](Operation *op) -> WalkResult {
+      if (auto q = mlir::dyn_cast<ONNXQuantizeLinearOp>(op)) {
+        Value zp = q.getYZeroPoint();
+        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+          zpElemType = zpShaped.getElementType();
+          return WalkResult::interrupt();
+        }
+      } else if (auto dq = mlir::dyn_cast<ONNXDequantizeLinearOp>(op)) {
+        Value zp = dq.getXZeroPoint();
+        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
+          zpElemType = zpShaped.getElementType();
+          return WalkResult::interrupt();
+        }
+      }
+      return WalkResult::advance();
+    });
+    return zpElemType;
+  }
+
+  // Data-movement ops that should be surrounded by Q/DQ.
+  bool isDMAOp(Operation *op) {
+    return isa<ONNXReshapeOp, ONNXTransposeOp>(op);
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *ctx = &getContext();
+    OpBuilder builder(ctx);
+
+    Block &entryBlock = func.getBody().front();
+    Location entryLoc = func.getLoc();
+
+    // One scalar scale constant (1.0f) shared by all inserted Q/DQ pairs.
+    RankedTensorType scaleType =
+        RankedTensorType::get({}, builder.getF32Type());
+    DenseElementsAttr scaleAttr =
+        DenseElementsAttr::get(scaleType, {1.0f});
+    OperationState scaleState(entryLoc, "onnx.Constant");
+    scaleState.addAttribute("value", scaleAttr);
+    scaleState.addTypes(scaleAttr.getType());
+    Operation *scaleOp = Operation::create(scaleState);
+    entryBlock.getOperations().insert(entryBlock.begin(), scaleOp);
+    Value scaleVal = scaleOp->getResult(0);
+
+    // One scalar zero-point constant with the element type found above.
+    Type zpEleType = extractZeroPointType(builder, func);
+    RankedTensorType zpType = RankedTensorType::get({}, zpEleType);
+    DenseElementsAttr zpAttr;
+    if (isa<IntegerType>(zpEleType)) {
+      auto intType = mlir::dyn_cast<IntegerType>(zpEleType);
+      unsigned width = intType.getWidth();
+      bool isSigned = intType.isSignedInteger();
+
+      if (width == 8) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint8_t>(0)});
+        }
+      } else if (width == 16) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int16_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint16_t>(0)});
+        }
+      } else if (width == 32) {
+        if (isSigned) {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<int32_t>(0)});
+        } else {
+          zpAttr = DenseElementsAttr::get(zpType, {static_cast<uint32_t>(0)});
+        }
+      } else {
+        // fallback: default int8 zero-point
+        zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
+      }
+    } else {
+      // fallback if not integer
+      zpAttr = DenseElementsAttr::get(zpType, {static_cast<int8_t>(0)});
+    }
+
+    OperationState zpState(entryLoc, "onnx.Constant");
+    zpState.addAttribute("value", zpAttr);
+    zpState.addTypes(zpAttr.getType());
+    Operation *zpOp = Operation::create(zpState);
+    entryBlock.getOperations().insert(entryBlock.begin(), zpOp);
+    Value zpVal = zpOp->getResult(0);
+
+    // Map each wrapped value to its DQ result so a value feeding several
+    // DMA ops gets a single Q/DQ pair.
+    llvm::SmallDenseMap<Value, Value> producerToDQ;
+
+    for (Operation &opRef :
+        llvm::make_early_inc_range(func.getBody().getOps())) {
+      Operation *op = &opRef;
+
+      if (isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(op) ||
+          !isDMAOp(op))
+        continue;
+
+      Location loc = op->getLoc();
+
+      for (Value operand : op->getOperands()) {
+        // Operands already produced by a DQ need no extra wrapping.
+        Operation *def = operand.getDefiningOp();
+        if (def && isa<ONNXDequantizeLinearOp>(def))
+          continue;
+
+        auto it = producerToDQ.find(operand);
+        if (it != producerToDQ.end()) {
+          op->replaceUsesOfWith(operand, it->second);
+          continue;
+        }
+
+        builder.setInsertionPoint(op);
+
+        ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());
+
+        // Only float values are wrapped; the Q result keeps the operand
+        // shape but uses the zero-point element type.
+        Type qResultType = operand.getType();
+        if (isa<FloatType>(operand.getType()) ||
+            (inShaped && isa<FloatType>(inShaped.getElementType()))) {
+          if (inShaped)
+            qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
+          else
+            qResultType = RankedTensorType::get({}, zpEleType);
+
+          auto q = builder.create<ONNXQuantizeLinearOp>(
+              loc, qResultType, operand, scaleVal, zpVal);
+          auto dq = builder.create<ONNXDequantizeLinearOp>(
+              loc, operand.getType(), q.getResult(), scaleVal, zpVal);
+
+          producerToDQ.try_emplace(operand, dq.getResult());
+          op->replaceUsesOfWith(operand, dq.getResult());
+        }
+      }
+    }
+  }
+};
+} // namespace
+
+namespace onnx_mlir {
+std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass() {
+  return std::make_unique<AddQDQAroundOp>();
+}
+} // namespace onnx_mlir
\ No newline at end of file
diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt
index 7862443935..2ba5887acc 100644
--- a/src/Dialect/ONNX/Transforms/CMakeLists.txt
+++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -44,7 +44,8 @@ add_onnx_mlir_library(OMONNXRewrite
   ConstProp.cpp
   QDQAroundOpOpt.cpp
   QDQOpt.cpp
-  DQBinaryQOpt.cpp
+  DQBinaryQOpt.cpp
+  AddQDQOpt.cpp
   ConvOpt.cpp
   Decompose.cpp
   DecomposeEinsum.cpp
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 8e4855e6dc..f0420d4b63 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -58,6 +58,7 @@
 std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();
 std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
 std::unique_ptr<mlir::Pass> createFoldDQBinaryQPass();
+std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass();
 
 /// Pass for instrument the ops in specific stage.
 std::unique_ptr<mlir::Pass> createInstrumentPass();
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index ba6635b466..50fd89cb80 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -83,6 +83,10 @@ void registerOMPasses(int optLevel) {
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createFoldDQBinaryQPass();
   });
 
+  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
+    return createMissingQDQAroundOpOptONNXToONNXPass();
+  });
+
   mlir::registerPass(
       []() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
diff --git a/test/mlir/onnx/onnx_add_qdq.mlir b/test/mlir/onnx/onnx_add_qdq.mlir
new file mode 100644
index 0000000000..8a0c63ee19
--- /dev/null
+++ b/test/mlir/onnx/onnx_add_qdq.mlir
@@ -0,0 +1,33 @@
+// RUN: onnx-mlir-opt --add-qdq-around-op %s | FileCheck %s
+
+func.func @test_inserted_qdq(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %shape = onnx.Constant {value = dense<[2, 2]> : tensor<2xi64>} : tensor<2xi64>
+  %cst = onnx.Constant {value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32>
+  %init = onnx.Constant dense<2.0> : tensor<f32>
+  %cst_qdq_zp = onnx.Constant dense<0> : tensor<i16>
+  %cst_qdq_s = onnx.Constant dense<1.52590219E-5> : tensor<f32>
+  %0 = "onnx.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "onnx.Mul"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
+  %2 = "onnx.Reshape"(%1, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
+  %3 = "onnx.QuantizeLinear"(%0, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x2xf32>, tensor<f32>, tensor<i16>) -> tensor<2x2xi16>
+  %4 = "onnx.DequantizeLinear"(%3, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x2xi16>, tensor<f32>, tensor<i16>) -> tensor<2x2xf32>
+  %5 = "onnx.Reshape"(%4, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
+  %6 = "onnx.Transpose"(%0) {perm = [0, 1]} : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  %7 = "onnx.Add"(%5, %6) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  return %7 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_inserted_qdq
+// CHECK: onnx.Add
+// CHECK: onnx.Mul
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Reshape
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Reshape
+// CHECK: onnx.QuantizeLinear
+// CHECK: onnx.DequantizeLinear
+// CHECK: onnx.Transpose
+// CHECK: onnx.Add
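
Reviewer note: a minimal before/after sketch of the rewrite this pass performs, assuming a function with no pre-existing Q/DQ pair, so the zero-point element type falls back to i8 and the scale defaults to 1.0f; the %scale/%zp names are illustrative, not emitted verbatim by the pass:

  // Before: a data-movement op consuming a plain float tensor.
  %1 = "onnx.Transpose"(%0) {perm = [1, 0]} : (tensor<2x2xf32>) -> tensor<2x2xf32>

  // After --add-qdq-around-op: the operand is wrapped in a Q/DQ pair; the
  // scalar %scale (1.0 : f32) and %zp (0 : i8) constants are hoisted to the
  // top of the entry block and shared by every inserted pair.
  %q = "onnx.QuantizeLinear"(%0, %scale, %zp) : (tensor<2x2xf32>, tensor<f32>, tensor<i8>) -> tensor<2x2xi8>
  %dq = "onnx.DequantizeLinear"(%q, %scale, %zp) : (tensor<2x2xi8>, tensor<f32>, tensor<i8>) -> tensor<2x2xf32>
  %1 = "onnx.Transpose"(%dq) {perm = [1, 0]} : (tensor<2x2xf32>) -> tensor<2x2xf32>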