Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Compiler/OnnxToMlirPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
if (opts.enableRemoveDqQOp)
pm.addPass(createQDQOptONNXToONNXPass());

if (opts.enableAddQDQOp)
pm.addPass(createMissingQDQAroundOpOptONNXToONNXPass());

// One more call to ONNX shape inference/canonicalization/... to update
// shape if possible.
if (opts.enableONNXHybridPass) {
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/OnnxToMlirPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct OnnxToMlirOptions {
bool enableRemoveDqQOp = true;
bool enableRemoveDqQAroundOp = true;
bool enableRemoveBinary = false;
bool enableAddQDQOp = false;

bool disableRecomposeOption = false;
bool enableONNXHybridPass = true;
Expand Down
187 changes: 187 additions & 0 deletions src/Dialect/ONNX/Transforms/AddQDQOpt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@

// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "llvm/ADT/STLExtras.h"
#include <cmath>
#include <optional>
#include <variant>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace {

struct AddQDQAroundOp
    : public PassWrapper<AddQDQAroundOp, OperationPass<func::FuncOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddQDQAroundOp)

  StringRef getArgument() const final { return "add-qdq-around-op"; }
  StringRef getDescription() const final {
    return "Add Q, DQ around nodes which are missing them";
  }

  /// Return the zero-point element type of the first QuantizeLinear or
  /// DequantizeLinear op found in \p func, so that newly inserted Q/DQ pairs
  /// use the same quantized width as the rest of the model. Falls back to i8
  /// when the function contains no Q/DQ op (or none with a shaped zero-point).
  /// FuncOp is a value-semantic handle, so it is taken by value.
  Type extractZeroPointType(OpBuilder &builder, func::FuncOp func) {
    Type zpElemType = builder.getIntegerType(8); // default int8

    func.walk([&](Operation *op) -> WalkResult {
      if (auto q = mlir::dyn_cast<ONNXQuantizeLinearOp>(op)) {
        Value zp = q.getYZeroPoint();
        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
          zpElemType = zpShaped.getElementType();
          return WalkResult::interrupt();
        }
      } else if (auto dq = mlir::dyn_cast<ONNXDequantizeLinearOp>(op)) {
        Value zp = dq.getXZeroPoint();
        if (auto zpShaped = llvm::dyn_cast<ShapedType>(zp.getType())) {
          zpElemType = zpShaped.getElementType();
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });
    return zpElemType;
  }

  /// Data-movement ("DMA") ops: layout/shape-manipulating ONNX ops around
  /// which this pass inserts identity Q/DQ pairs.
  bool isDMAOp(Operation *op) {
    return isa<ONNXReshapeOp, ONNXFlattenOp, ONNXSqueezeOp, ONNXUnsqueezeOp,
        ONNXExpandOp, ONNXTransposeOp, ONNXIdentityOp, ONNXSliceOp,
        ONNXConcatOp, ONNXSplitOp, ONNXGatherOp, ONNXGatherElementsOp,
        ONNXGatherNDOp, ONNXScatterOp, ONNXScatterElementsOp, ONNXScatterNDOp,
        ONNXPadOp, ONNXCastOp, ONNXShapeOp, ONNXConstantOfShapeOp, ONNXTileOp,
        ONNXDepthToSpaceOp, ONNXSpaceToDepthOp, ONNXResizeOp>(op);
  }

  void runOnOperation() override {
    func::FuncOp func = getOperation();
    MLIRContext *ctx = &getContext();
    OpBuilder builder(ctx);

    Block &entryBlock = func.getBody().front();
    Location entryLoc = func.getLoc();
    // Shared constants go to the top of the entry block so they dominate
    // every Q/DQ pair inserted below.
    // TODO(review): the constants are created eagerly; a function with no
    // matching DMA op keeps two dead onnx.Constant ops (harmless, DCE'd by
    // canonicalization, but could be created lazily instead).
    builder.setInsertionPointToStart(&entryBlock);

    // Identity scale (1.0f) used by all inserted Q/DQ pairs.
    RankedTensorType scaleType =
        RankedTensorType::get({}, builder.getF32Type());
    DenseElementsAttr scaleAttr = DenseElementsAttr::get(scaleType, {1.0f});
    OperationState scaleState(entryLoc, "onnx.Constant");
    scaleState.addAttribute("value", scaleAttr);
    scaleState.addTypes(scaleAttr.getType());
    Value scaleVal = builder.create(scaleState)->getResult(0);

    // Zero-valued zero-point of the element type already used by the model's
    // existing Q/DQ ops (default i8). Builder::getZeroAttr handles any
    // integer width and signedness uniformly; the previous hand-rolled
    // per-width switch keyed on isSignedInteger(), which is false for MLIR's
    // default *signless* i8/i16/i32 and so mis-classified the common case.
    Type zpEleType = extractZeroPointType(builder, func);
    RankedTensorType zpType = RankedTensorType::get({}, zpEleType);
    auto zpAttr = llvm::cast<DenseElementsAttr>(builder.getZeroAttr(zpType));

    OperationState zpState(entryLoc, "onnx.Constant");
    zpState.addAttribute("value", zpAttr);
    zpState.addTypes(zpAttr.getType());
    Value zpVal = builder.create(zpState)->getResult(0);

    // Map a float value to the DequantizeLinear result that replaces it, so
    // a value feeding several DMA ops gets exactly one Q/DQ pair.
    llvm::SmallDenseMap<Value, Value> producerToDQ;

    for (Operation &opRef : llvm::make_early_inc_range(func.getOps())) {
      Operation *op = &opRef;

      // Only wrap data-movement ops; constants and pre-existing Q/DQ ops are
      // left untouched.
      if ((isa<ONNXConstantOp, ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(
              op)) ||
          !(isDMAOp(op)))
        continue;

      Location loc = op->getLoc();

      for (Value operand : op->getOperands()) {
        // Operand already comes straight out of a Q or DQ: nothing to add.
        Operation *def = operand.getDefiningOp();
        if (def && isa<ONNXQuantizeLinearOp, ONNXDequantizeLinearOp>(def))
          continue;

        // Reuse the DQ created when an earlier consumer saw this value;
        // iteration is in program order, so the cached DQ dominates this use.
        auto it = producerToDQ.find(operand);
        if (it != producerToDQ.end()) {
          op->replaceUsesOfWith(operand, it->second);
          continue;
        }

        ShapedType inShaped = llvm::dyn_cast<ShapedType>(operand.getType());

        // Only float operands are quantized; integer/index operands (e.g.
        // the shape input of Reshape) pass through unchanged.
        if (isa<FloatType>(operand.getType()) ||
            (inShaped && isa<FloatType>(inShaped.getElementType()))) {
          builder.setInsertionPoint(op);

          // Q result keeps the operand's shape but switches to the
          // quantized element type; scalars become rank-0 tensors.
          Type qResultType;
          if (inShaped)
            qResultType = RankedTensorType::get(inShaped.getShape(), zpEleType);
          else
            qResultType = RankedTensorType::get({}, zpEleType);

          auto q = builder.create<ONNXQuantizeLinearOp>(
              loc, qResultType, operand, scaleVal, zpVal);
          auto dq = builder.create<ONNXDequantizeLinearOp>(
              loc, operand.getType(), q.getResult(), scaleVal, zpVal);

          producerToDQ.try_emplace(operand, dq.getResult());
          op->replaceUsesOfWith(operand, dq.getResult());
        }
      }
    }
  }
};
} // namespace

namespace onnx_mlir {
/// Factory for the pass that wraps data-movement ONNX ops in identity
/// QuantizeLinear/DequantizeLinear pairs (the AddQDQAroundOp pass above).
std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass() {
  return std::make_unique<AddQDQAroundOp>();
}
} // namespace onnx_mlir
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ add_onnx_mlir_library(OMONNXRewrite
ConstProp.cpp
QDQAroundOpOpt.cpp
QDQOpt.cpp
DQBinaryQOpt.cpp
DQBinaryQOpt.cpp
AddQDQOpt.cpp
ConvOpt.cpp
Decompose.cpp
DecomposeEinsum.cpp
Expand Down
1 change: 1 addition & 0 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();

std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
std::unique_ptr<mlir::Pass> createFoldDQBinaryQPass();
std::unique_ptr<mlir::Pass> createMissingQDQAroundOpOptONNXToONNXPass();

/// Pass for instrument the ops in specific stage.
std::unique_ptr<mlir::Pass> createInstrumentPass();
Expand Down
4 changes: 4 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ void registerOMPasses(int optLevel) {
return createFoldDQBinaryQPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createMissingQDQAroundOpOptONNXToONNXPass();
});

mlir::registerPass(
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });

Expand Down
33 changes: 33 additions & 0 deletions test/mlir/onnx/onnx_add_qdq.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: onnx-mlir-opt --add-qdq-around-op %s | FileCheck %s

// Checks that the add-qdq-around-op pass wraps the float inputs of
// data-movement ops (Reshape, Transpose) in QuantizeLinear/DequantizeLinear
// pairs, while leaving untouched:
//   - the integer shape operand of Reshape,
//   - the input %4 that already comes from a DequantizeLinear.
// The pre-existing i16 Q/DQ pair makes the pass pick i16 as the inserted
// zero-point type.
func.func @test_inserted_qdq(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%shape = onnx.Constant {value = dense<[2, 2]> : tensor<2xi64>} : tensor<2xi64>
%cst = onnx.Constant {value = dense<1.0> : tensor<2x2xf32>} : tensor<2x2xf32>
%init = onnx.Constant dense<2.0> : tensor<f32>
%cst_qdq_zp = onnx.Constant dense<0> : tensor<i16>
%cst_qdq_s = onnx.Constant dense<1.52590219E-5> : tensor<f32>
%0 = "onnx.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "onnx.Mul"(%0, %init) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>
%2 = "onnx.Reshape"(%1, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
%3 = "onnx.QuantizeLinear"(%0, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<2x2xf32>, tensor<f32>, tensor<i16>) -> tensor<2x2xi16>
%4 = "onnx.DequantizeLinear"(%3, %cst_qdq_s, %cst_qdq_zp) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x2xi16>, tensor<f32>, tensor<i16>) -> tensor<2x2xf32>
%5 = "onnx.Reshape"(%4, %shape) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
%6 = "onnx.Transpose"(%0) {perm = [0, 1]} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%7 = "onnx.Add"(%5, %6) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>

return %7 : tensor<2x2xf32>
}

// CHECK-LABEL: func.func @test_inserted_qdq
// CHECK: onnx.Add
// CHECK: onnx.Mul
// CHECK: onnx.QuantizeLinear
// CHECK: onnx.DequantizeLinear
// CHECK: onnx.Reshape
// CHECK: onnx.QuantizeLinear
// CHECK: onnx.DequantizeLinear
// CHECK: onnx.Reshape
// CHECK: onnx.QuantizeLinear
// CHECK: onnx.DequantizeLinear
// CHECK: onnx.Transpose
// CHECK: onnx.Add