Skip to content

Commit bdea6ce

Browse files
authored
Merge pull request #412 from xiaohanAMD/xiao.add_remove_qdq_aroundop
Xiao.add remove qdq aroundop
2 parents 8aaef29 + e199822 commit bdea6ce

17 files changed

+478
-57
lines changed

src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
8484
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
8585
opts.enableQuarkQuantizedLegalization));
8686

87+
// Pass for removing Dq and Q around data movement in Dq->op->Q Ops chain
88+
if (opts.enableRemoveDqQAroundOp)
89+
pm.addPass(createQDQAroundOpOptONNXToONNXPass());
90+
91+
// Pass for removing redundant Dq->Q Ops chain
8792
// Passes for removing redundant concat, slice and cast QDQ Ops
8893
if (opts.enableRemoveDqQOp)
8994
pm.addPass(createQDQOptONNXToONNXPass());

src/Compiler/OnnxToMlirPasses.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ struct OnnxToMlirOptions {
1717
bool enableConvTransposeDecomposeToPhasedConv = false;
1818
bool enableConvTranspose1dDecomposeToPhasedConv = false;
1919
bool enableRemoveDqQOp = true;
20+
bool enableRemoveDqQAroundOp = true;
2021

2122
bool disableRecomposeOption = false;
2223
bool enableONNXHybridPass = true;

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/ADT/TypeSwitch.h"
1818
#include "llvm/Support/Path.h"
1919

20+
#include "mlir/IR/BuiltinTypes.h"
2021
#include "src/Dialect/Mlir/IndexExpr.hpp"
2122
#include "src/Dialect/ONNX/DialectBuilder.hpp"
2223
#include "src/Dialect/ONNX/ONNXLayoutHelper.hpp"
@@ -315,6 +316,25 @@ ElementsAttr getElementAttributeFromONNXValue(Value value) {
315316
return nullptr;
316317
}
317318

319+
// compare two ElementsAttr, except for their internal buffer size
320+
bool compareValueFromElementAttribute(
321+
ElementsAttr &attr1, ElementsAttr &attr2) {
322+
if (attr1.getType() != attr2.getType()) {
323+
return false;
324+
}
325+
if (attr1.getNumElements() != attr2.getNumElements()) {
326+
return false;
327+
}
328+
auto it1 = attr1.getValues<mlir::Attribute>().begin();
329+
auto it2 = attr2.getValues<mlir::Attribute>().begin();
330+
for (; it1 != attr1.getValues<mlir::Attribute>().end(); ++it1, ++it2) {
331+
if (*it1 != *it2) {
332+
return false;
333+
}
334+
}
335+
return true;
336+
}
337+
318338
// Returns the ConstantOp which defines an MLIR Value or null.
319339
ONNXConstantOp getONNXConstantOp(Value value) {
320340
return mlir::dyn_cast_or_null<ONNXConstantOp>(value.getDefiningOp());
@@ -854,6 +874,54 @@ bool isIdentityReshape(
854874
return isIdentityReshape(inputTensor, outputTensor, dimAnalysis);
855875
}
856876

877+
// Decides whether a DequantizeLinear / QuantizeLinear pair uses identical
// quantization parameters.
//
// Returns true only when all of the following hold:
//   * both ops carry the same `axis` and `block_size` attributes;
//   * both zero-points resolve to element attributes with equal values;
//   * both scales resolve to element attributes with equal values;
//   * the DQ input and the Q result are tensors with the same element type.
// Any parameter that cannot be resolved to an element attribute makes the
// pair compare as different.
bool isDequantQuantSame(
    mlir::ONNXDequantizeLinearOp dqOp, mlir::ONNXQuantizeLinearOp qOp) {
  // Attributes must agree.
  if (qOp.getAxis() != dqOp.getAxis() ||
      qOp.getBlockSize() != dqOp.getBlockSize())
    return false;

  // Zero-points: both must be resolvable and hold equal values.
  ElementsAttr dqZeroPoint =
      getElementAttributeFromONNXValue(dqOp.getXZeroPoint());
  ElementsAttr qZeroPoint =
      getElementAttributeFromONNXValue(qOp.getYZeroPoint());
  if (!(dqZeroPoint && qZeroPoint))
    return false;
  if (!compareValueFromElementAttribute(dqZeroPoint, qZeroPoint))
    return false;

  // Scales: same requirement as for the zero-points.
  ElementsAttr dqScale = getElementAttributeFromONNXValue(dqOp.getXScale());
  ElementsAttr qScale = getElementAttributeFromONNXValue(qOp.getYScale());
  if (!(dqScale && qScale))
    return false;
  if (!compareValueFromElementAttribute(dqScale, qScale))
    return false;

  // The quantized element type entering DQ must be the one leaving Q.
  auto dqInputTensor = mlir::dyn_cast<TensorType>(dqOp.getX().getType());
  if (!dqInputTensor)
    return false;
  auto qOutputTensor = mlir::dyn_cast<TensorType>(qOp.getResult().getType());
  if (!qOutputTensor)
    return false;
  return qOutputTensor.getElementType() == dqInputTensor.getElementType();
}
924+
857925
//===----------------------------------------------------------------------===//
858926
// Support for location.
859927
//===----------------------------------------------------------------------===//

src/Dialect/ONNX/ONNXOps/OpHelper.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ void ArrayAttrIntVals(mlir::ArrayAttr a, mlir::SmallVectorImpl<int64_t> &i);
174174

175175
mlir::ElementsAttr getElementAttributeFromONNXValue(mlir::Value value);
176176

177+
bool compareValueFromElementAttribute(
178+
mlir::ElementsAttr &attr1, mlir::ElementsAttr &attr2);
179+
177180
mlir::ONNXConstantOp getONNXConstantOp(mlir::Value value);
178181

179182
// Obtain an array of int64_t values stored in ONNXConstantOp and append it to
@@ -397,6 +400,8 @@ bool isIdentityReshape(
397400
bool isIdentityReshape(mlir::Value input, mlir::Value output,
398401
const DimAnalysis *dimAnalysis = nullptr);
399402

403+
bool isDequantQuantSame(
404+
mlir::ONNXDequantizeLinearOp dqOp, mlir::ONNXQuantizeLinearOp qOp);
400405
//===----------------------------------------------------------------------===//
401406
// Support for location.
402407
//===----------------------------------------------------------------------===//

src/Dialect/ONNX/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ add_onnx_mlir_rewriter(DecomposeConvTranspose1dPhased)
77

88
add_onnx_mlir_rewriter(ConstProp)
99
add_onnx_mlir_rewriter(ConvOpt)
10+
add_onnx_mlir_rewriter(QDQAroundOpOpt)
11+
add_onnx_mlir_rewriter(QDQOpt)
1012

1113
add_onnx_mlir_library(OMShapeInference
1214
ShapeInference.cpp
@@ -42,6 +44,7 @@ add_onnx_mlir_library(OMInstrumentONNX
4244

4345
add_onnx_mlir_library(OMONNXRewrite
4446
ConstProp.cpp
47+
QDQAroundOpOpt.cpp
4548
QDQOpt.cpp
4649
ConvOpt.cpp
4750
Decompose.cpp
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===- QDQAroundOpOpt.cpp - Remove DQ, Q operations around data movement ops
2+
//--------*- C++ -*-===//
3+
//
4+
// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include <cmath>
9+
#include <mlir/IR/IRMapping.h>
10+
#include <mlir/IR/Operation.h>
11+
#include <mlir/IR/PatternMatch.h>
12+
#include <mlir/Pass/Pass.h>
13+
#include <mlir/Transforms/DialectConversion.h>
14+
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
15+
#include <src/Dialect/ONNX/ONNXOps.hpp>
16+
#include <src/Dialect/ONNX/ONNXOps/OpHelper.hpp>
17+
18+
using namespace mlir;
19+
using namespace onnx_mlir;
20+
struct InputAndOutput {
21+
Value input;
22+
Value output;
23+
};
24+
25+
InputAndOutput getDataInputOutput(ONNXTransposeOp transposeOp) {
26+
return {transposeOp.getData(), transposeOp.getTransposed()};
27+
}
28+
InputAndOutput getDataInputOutput(ONNXUnsqueezeOp unsqueezeOp) {
29+
return {unsqueezeOp.getData(), unsqueezeOp.getExpanded()};
30+
}
31+
InputAndOutput getDataInputOutput(ONNXSqueezeOp squeezeOp) {
32+
return {squeezeOp.getData(), squeezeOp.getSqueezed()};
33+
}
34+
InputAndOutput getDataInputOutput(ONNXReshapeOp reshapeOp) {
35+
return {reshapeOp.getData(), reshapeOp.getReshaped()};
36+
}
37+
InputAndOutput getDataInputOutput(ONNXGatherOp gatherOp) {
38+
return {gatherOp.getData(), gatherOp.getOutput()};
39+
}
40+
InputAndOutput getDataInputOutput(ONNXSliceOp sliceOp) {
41+
return {sliceOp.getData(), sliceOp.getOutput()};
42+
}
43+
InputAndOutput getDataInputOutput(ONNXResizeOp resizeOp) {
44+
return {resizeOp.getX(), resizeOp.getY()};
45+
}
46+
InputAndOutput getDataInputOutput(ONNXFlattenOp flattenOp) {
47+
return {flattenOp.getInput(), flattenOp.getOutput()};
48+
}
49+
namespace {
50+
template <typename T>
51+
class RemoveQDQAroundOpPattern : public OpRewritePattern<T> {
52+
public:
53+
using OpRewritePattern<T>::OpRewritePattern;
54+
55+
LogicalResult matchAndRewrite(
56+
T op, PatternRewriter &rewriter) const override {
57+
if (llvm::isa<ONNXResizeOp>(op)) {
58+
auto &resizeOp = llvm::cast<ONNXResizeOp>(op);
59+
if (resizeOp.getMode() != "nearest") {
60+
return failure();
61+
}
62+
}
63+
InputAndOutput opIO = getDataInputOutput(op);
64+
65+
auto dqOp = opIO.input.getDefiningOp<ONNXDequantizeLinearOp>();
66+
// Only run this pass if Quantizelization is on tensor
67+
if (!dqOp || !isScalarConstantTensor(dqOp.getXScale()) ||
68+
!isScalarConstantTensor(dqOp.getXZeroPoint())) {
69+
return failure();
70+
}
71+
if (!opIO.output.hasOneUse()) {
72+
return failure();
73+
}
74+
75+
Operation *firstOp = *(opIO.output.getUsers().begin());
76+
if (auto qOp = dyn_cast<ONNXQuantizeLinearOp>(firstOp)) {
77+
if (!isScalarConstantTensor(qOp.getYScale()) ||
78+
!isScalarConstantTensor(qOp.getYZeroPoint())) {
79+
return failure();
80+
}
81+
if (!isDequantQuantSame(dqOp, qOp))
82+
return failure();
83+
84+
// Map dqOp inputs to dqOp's inputs
85+
IRMapping irMapping;
86+
irMapping.map(dqOp, dqOp.getX());
87+
88+
SmallVector<Value> newInputs;
89+
transform(op->getOperands(), std::back_inserter(newInputs),
90+
[&](Value operand) { return irMapping.lookupOrDefault(operand); });
91+
92+
auto newOp =
93+
rewriter.create<T>(op.getLoc(), TypeRange{qOp.getResult().getType()},
94+
ValueRange{newInputs}, op->getAttrs());
95+
rewriter.replaceOp(qOp, newOp.getResult());
96+
return success();
97+
}
98+
};
99+
};
100+
struct QDQAroundOpOptONNXToONNXPass
101+
: public PassWrapper<QDQAroundOpOptONNXToONNXPass,
102+
OperationPass<func::FuncOp>> {
103+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QDQAroundOpOptONNXToONNXPass)
104+
StringRef getArgument() const override {
105+
return "qdq-around-op-opt-onnx-to-onnx";
106+
}
107+
StringRef getDescription() const override {
108+
return "Remove QDQ around ops if safe.";
109+
}
110+
111+
void runOnOperation() override {
112+
auto function = getOperation();
113+
auto *ctx = &getContext();
114+
RewritePatternSet patterns(ctx);
115+
// ONNXReduceSumOp is expecting high precision value, it failed to compile
116+
// during applying this pass, so for now there is no dq, q removal around
117+
// ReduceSum
118+
patterns.add<RemoveQDQAroundOpPattern<ONNXTransposeOp>,
119+
RemoveQDQAroundOpPattern<ONNXUnsqueezeOp>,
120+
RemoveQDQAroundOpPattern<ONNXSqueezeOp>,
121+
RemoveQDQAroundOpPattern<ONNXReshapeOp>,
122+
RemoveQDQAroundOpPattern<ONNXResizeOp>,
123+
RemoveQDQAroundOpPattern<ONNXGatherOp>,
124+
RemoveQDQAroundOpPattern<ONNXSliceOp>,
125+
RemoveQDQAroundOpPattern<ONNXFlattenOp>>(patterns.getContext());
126+
if (failed(applyPatternsGreedily(function, std::move(patterns))))
127+
signalPassFailure();
128+
}
129+
};
130+
} // namespace
131+
132+
namespace onnx_mlir {
133+
std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass() {
134+
return std::make_unique<QDQAroundOpOptONNXToONNXPass>();
135+
}
136+
} // namespace onnx_mlir

src/Dialect/ONNX/Transforms/QDQOpt.cpp

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,20 @@
44
//
55
//===----------------------------------------------------------------------===//
66

7-
#include "mlir/IR/Attributes.h"
8-
#include "mlir/IR/BuiltinTypes.h"
9-
#include "mlir/IR/Operation.h"
107
#include "mlir/IR/PatternMatch.h"
118
#include "mlir/Pass/Pass.h"
129
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1310
#include "src/Dialect/ONNX/ONNXOps.hpp"
1411
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
1512
#include "src/Pass/Passes.hpp"
1613

17-
#include "llvm/ADT/STLExtras.h"
18-
#include "llvm/ADT/SmallSet.h"
1914
#include <cmath>
2015

2116
using namespace mlir;
2217
using namespace onnx_mlir;
2318

2419
namespace {
2520

26-
//===----------------------------------------------------------------------===//
27-
// Helper Functions
28-
//===----------------------------------------------------------------------===//
29-
30-
static ElementsAttr getElementAttributeFromConstant(Value val) {
31-
if (!val)
32-
return nullptr;
33-
if (auto constOp = val.getDefiningOp<ONNXConstantOp>())
34-
return mlir::dyn_cast<ElementsAttr>(constOp.getValueAttr());
35-
return nullptr;
36-
}
37-
3821
//===----------------------------------------------------------------------===//
3922
// Pattern to remove QDQ pairs
4023
//===----------------------------------------------------------------------===//
@@ -47,47 +30,8 @@ struct FoldQDQPattern : public OpRewritePattern<ONNXQuantizeLinearOp> {
4730
auto dqOp = qOp.getX().getDefiningOp<ONNXDequantizeLinearOp>();
4831
if (!dqOp)
4932
return failure();
50-
51-
// 1. Check Attributes
52-
if (qOp.getAxis() != dqOp.getAxis())
53-
return failure();
54-
if (qOp.getBlockSize() != dqOp.getBlockSize())
55-
return failure();
56-
57-
// 2. Check zero-points
58-
auto zpAttr1 = getElementAttributeFromConstant(dqOp.getXZeroPoint());
59-
auto zpAttr2 = getElementAttributeFromConstant(qOp.getYZeroPoint());
60-
if (!zpAttr1 && !zpAttr2)
61-
return failure();
62-
if (zpAttr1 != zpAttr2)
63-
return failure();
64-
65-
// 3. Check Scales.
66-
auto scaleAttr1 = getElementAttributeFromConstant(dqOp.getXScale());
67-
auto scaleAttr2 = getElementAttributeFromConstant(qOp.getYScale());
68-
if (!scaleAttr1 && !scaleAttr2)
69-
return failure();
70-
if (scaleAttr1 != scaleAttr2)
71-
return failure();
72-
73-
// 4. Check data type consistency of the entire DQ->Q chain.
74-
// The original quantized type before DQ must match the final quantized
75-
// type after Q.
76-
auto dqInTypeOp = dqOp.getX().getType();
77-
auto qOutTypeOp = qOp.getResult().getType();
78-
79-
if (auto dqInTensorType = dqInTypeOp.dyn_cast<TensorType>()) {
80-
if (auto qOutTensorType = qOutTypeOp.dyn_cast<TensorType>()) {
81-
if (qOutTensorType.getElementType() !=
82-
dqInTensorType.getElementType()) {
83-
return failure();
84-
}
85-
} else {
86-
return failure();
87-
}
88-
} else {
33+
if (!isDequantQuantSame(dqOp, qOp))
8934
return failure();
90-
}
9135
rewriter.replaceOp(qOp, dqOp.getX());
9236
return success();
9337
}

src/Pass/Passes.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void configureConstPropONNXToONNXPass(bool roundFPToInt, int expansionBound,
5454
llvm::ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled);
5555

5656
std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
57+
std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();
5758

5859
std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
5960

src/Tools/onnx-mlir-opt/RegisterPasses.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ void registerOMPasses(int optLevel) {
7171
return createQDQOptONNXToONNXPass();
7272
});
7373

74+
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
75+
return createQDQAroundOpOptONNXToONNXPass();
76+
});
77+
78+
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
79+
return createQDQOptONNXToONNXPass();
80+
});
81+
7482
mlir::registerPass(
7583
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
7684

0 commit comments

Comments
 (0)