Skip to content

Commit f1ecdb1

Browse files
authored
Merge pull request #408 from SushmitaThakallapalli1980/feature/onnx-to-tosa
Created passes for removing redundant cases of DQ-Concat-Q, DQ-Cast-Q, DQ-Slice-Q and DQ-Q.
2 parents 4e9c1be + ab6bb16 commit f1ecdb1

File tree

10 files changed

+372
-0
lines changed

10 files changed

+372
-0
lines changed

src/Compiler/CompilerPasses.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
141141
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
142142
opts.enableQuarkQuantizedLegalization));
143143

144+
// Passes for removing redundant QDQ patterns: plain DQ-Q pairs as well as DQ-{Concat,Slice,Cast}-Q
145+
if (opts.enableRemoveDqQOp)
146+
pm.addPass(createQDQOptONNXToONNXPass());
147+
144148
// One more call to ONNX shape inference/canonicalization/... to update
145149
// shape if possible.
146150
if (enableONNXHybridPass) {

src/Compiler/CompilerPasses.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct OnnxToMlirOptions {
3030
bool enableConvTransposeDecompose = false;
3131
bool enableConvTransposeDecomposeToPhasedConv = false;
3232
bool enableConvTranspose1dDecomposeToPhasedConv = false;
33+
bool enableRemoveDqQOp = true;
3334
};
3435

3536
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,

src/Dialect/ONNX/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_onnx_mlir_rewriter(DecomposeConvTranspose1dPhased)
77

88
add_onnx_mlir_rewriter(ConstProp)
99
add_onnx_mlir_rewriter(ConvOpt)
10+
add_onnx_mlir_rewriter(QDQOpt)
1011

1112
add_onnx_mlir_library(OMShapeInference
1213
ShapeInference.cpp
@@ -42,6 +43,7 @@ add_onnx_mlir_library(OMInstrumentONNX
4243

4344
add_onnx_mlir_library(OMONNXRewrite
4445
ConstProp.cpp
46+
QDQOpt.cpp
4547
ConvOpt.cpp
4648
Decompose.cpp
4749
DecomposeEinsum.cpp
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//===- QDQOpt.cpp - Remove QDQ operations --------*- C++ -*-===//
2+
//
3+
// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
4+
//
5+
//===----------------------------------------------------------------------===//
6+
7+
#include "mlir/IR/Attributes.h"
8+
#include "mlir/IR/BuiltinTypes.h"
9+
#include "mlir/IR/Operation.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
#include "src/Dialect/ONNX/ONNXOps.hpp"
14+
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
15+
#include "src/Pass/Passes.hpp"
16+
17+
#include "llvm/ADT/STLExtras.h"
18+
#include "llvm/ADT/SmallSet.h"
19+
#include <cmath>
20+
21+
using namespace mlir;
22+
using namespace onnx_mlir;
23+
24+
namespace {
25+
26+
//===----------------------------------------------------------------------===//
27+
// Helper Functions
28+
//===----------------------------------------------------------------------===//
29+
30+
// Return the dense element attribute held by the ONNXConstantOp that
// produces `val`, or nullptr when `val` is null, is not produced by an
// ONNXConstantOp, or the constant's value is not an ElementsAttr.
static ElementsAttr getElementAttributeFromConstant(Value val) {
  if (!val)
    return nullptr;
  auto constOp = val.getDefiningOp<ONNXConstantOp>();
  if (!constOp)
    return nullptr;
  return mlir::dyn_cast<ElementsAttr>(constOp.getValueAttr());
}
37+
38+
//===----------------------------------------------------------------------===//
39+
// Pattern to remove QDQ pairs
40+
//===----------------------------------------------------------------------===//
41+
42+
/// Folds a DequantizeLinear -> QuantizeLinear pair into a no-op when the
/// round trip is provably the identity: both ops use the same axis and
/// block_size, the same constant scale and zero-point, and the quantized
/// element type going into DQ equals the one coming out of Q. In that case
/// the QuantizeLinear result is replaced by the original quantized input.
struct FoldQDQPattern : public OpRewritePattern<ONNXQuantizeLinearOp> {
  using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXQuantizeLinearOp qOp, PatternRewriter &rewriter) const override {
    // Anchor on Q and require its input to come directly from a DQ.
    auto dqOp = qOp.getX().getDefiningOp<ONNXDequantizeLinearOp>();
    if (!dqOp)
      return failure();

    // 1. Quantization attributes must agree on both ops.
    if (qOp.getAxis() != dqOp.getAxis() ||
        qOp.getBlockSize() != dqOp.getBlockSize())
      return failure();

    // 2. Zero-points must be identical compile-time constants. Attributes
    // are uniqued in MLIR, so pointer equality is value equality here.
    // Non-constant (or missing) zero-points are conservatively rejected.
    ElementsAttr dqZeroPoint =
        getElementAttributeFromConstant(dqOp.getXZeroPoint());
    ElementsAttr qZeroPoint =
        getElementAttributeFromConstant(qOp.getYZeroPoint());
    if (!dqZeroPoint || !qZeroPoint || dqZeroPoint != qZeroPoint)
      return failure();

    // 3. Scales must likewise be identical compile-time constants.
    ElementsAttr dqScale = getElementAttributeFromConstant(dqOp.getXScale());
    ElementsAttr qScale = getElementAttributeFromConstant(qOp.getYScale());
    if (!dqScale || !qScale || dqScale != qScale)
      return failure();

    // 4. The quantized element type must be preserved across the DQ->Q
    // chain: the type feeding DQ must match the type produced by Q.
    // Note: use the free-function mlir::dyn_cast; the member-form
    // Type::dyn_cast is deprecated and removed in newer MLIR.
    auto dqInType = mlir::dyn_cast<TensorType>(dqOp.getX().getType());
    auto qOutType = mlir::dyn_cast<TensorType>(qOp.getResult().getType());
    if (!dqInType || !qOutType ||
        dqInType.getElementType() != qOutType.getElementType())
      return failure();

    // The round trip is the identity: bypass both ops.
    rewriter.replaceOp(qOp, dqOp.getX());
    return success();
  }
};
95+
96+
//===----------------------------------------------------------------------===//
97+
// Pass to run QDQ removal
98+
//===----------------------------------------------------------------------===//
99+
100+
struct QDQOptONNXToONNXPass
101+
: public PassWrapper<QDQOptONNXToONNXPass, OperationPass<func::FuncOp>> {
102+
103+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QDQOptONNXToONNXPass)
104+
StringRef getArgument() const override { return "dqq-opt-onnx-to-onnx"; }
105+
StringRef getDescription() const override {
106+
return "Remove DqQ ops and surrounding DqQ if safe.";
107+
}
108+
109+
void runOnOperation() override {
110+
auto function = getOperation();
111+
RewritePatternSet patterns(&getContext());
112+
patterns.add<FoldQDQPattern>(&getContext());
113+
if (failed(applyPatternsGreedily(function, std::move(patterns))))
114+
signalPassFailure();
115+
}
116+
};
117+
} // namespace
118+
119+
namespace onnx_mlir {
120+
std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass() {
121+
return std::make_unique<QDQOptONNXToONNXPass>();
122+
}
123+
} // namespace onnx_mlir

src/Pass/Passes.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ void configureConstPropONNXToONNXPass(bool roundFPToInt, int expansionBound,
5555

5656
std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
5757

58+
std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
59+
5860
/// Pass for instrument the ops in specific stage.
5961
std::unique_ptr<mlir::Pass> createInstrumentPass();
6062
std::unique_ptr<mlir::Pass> createInstrumentPass(

src/Tools/onnx-mlir-opt/RegisterPasses.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ void registerOMPasses(int optLevel) {
6767
return createConstPropONNXToONNXPass();
6868
});
6969

70+
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
71+
return createQDQOptONNXToONNXPass();
72+
});
73+
7074
mlir::registerPass(
7175
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
7276

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: onnx-mlir-opt --canonicalize --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
2+
3+
func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5+
%1 = onnx.Constant dense<39664> : tensor<ui16>
6+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7+
%3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = f32} : (tensor<*xf32>) -> tensor<*xf32>
8+
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9+
return %4 : tensor<*xui16>
10+
}
11+
12+
// CHECK-LABEL: func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13+
// CHECK-NOT: onnx.DequantizeLinear
14+
// CHECK-NOT: onnx.Cast
15+
// CHECK-NOT: onnx.QuantizeLinear
16+
17+
func.func @test_cast_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16> {
18+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
19+
%1 = onnx.Constant dense<39664> : tensor<ui16>
20+
%2 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<*xui16>) -> tensor<*xf32>
21+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
22+
return %3 : tensor<*xui16>
23+
}
24+
25+
// CHECK-LABEL: func.func @test_cast_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16>
26+
// CHECK: onnx.Cast
27+
// CHECK: onnx.QuantizeLinear
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: onnx-mlir-opt --canonicalize --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
2+
3+
func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5+
%1 = onnx.Constant dense<39664> : tensor<ui16>
6+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7+
%3 = "onnx.Concat"(%2) {axis = 1 : si64} : (tensor<*xf32>) -> tensor<*xf32>
8+
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9+
return %4 : tensor<*xui16>
10+
}
11+
12+
// CHECK-LABEL: func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13+
// CHECK-NOT: onnx.DequantizeLinear
14+
// CHECK-NOT: onnx.Concat
15+
// CHECK-NOT: onnx.QuantizeLinear
16+
// CHECK: return %arg0 : tensor<*xui16>
17+
18+
func.func @test_concat_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16> {
19+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
20+
%1 = onnx.Constant dense<39664> : tensor<ui16>
21+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
22+
%3 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
23+
%4 = "onnx.Concat"(%2, %3) {axis = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
24+
%5 = "onnx.QuantizeLinear"(%4, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
25+
return %5 : tensor<*xui16>
26+
}
27+
28+
// CHECK-LABEL: func.func @test_concat_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16>
29+
// CHECK: onnx.DequantizeLinear
30+
// CHECK: onnx.Concat
31+
// CHECK: onnx.QuantizeLinear
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// RUN: onnx-mlir-opt --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
2+
3+
func.func @test_qdq_pattern1(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
4+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5+
%1 = onnx.Constant dense<39664> : tensor<ui16>
6+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
7+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
8+
return %3 : tensor<1x128x768xui16>
9+
10+
}
11+
12+
// CHECK-LABEL: func.func @test_qdq_pattern1(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16>
13+
// CHECK: return %arg0 : tensor<1x128x768xui16>
14+
// CHECK-NOT: onnx.DequantizeLinear
15+
// CHECK-NOT: onnx.QuantizeLinear
16+
17+
func.func @test_qdq_pattern2(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
18+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
19+
%1 = onnx.Constant dense<39664> : tensor<ui16>
20+
%2 = onnx.Constant dense<6.57987776E-5> : tensor<f32>
21+
%3 = onnx.Constant dense<45664> : tensor<ui16>
22+
%4 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
23+
%5 = "onnx.QuantizeLinear"(%4, %2, %3) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
24+
return %5 : tensor<1x128x768xui16>
25+
}
26+
27+
// CHECK-LABEL: func.func @test_qdq_pattern2(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16>
28+
// CHECK: onnx.DequantizeLinear
29+
// CHECK: onnx.QuantizeLinear
30+
31+
func.func @test_qdq_pattern3(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
32+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
33+
%1 = onnx.Constant dense<39664> : tensor<ui16>
34+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 2 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
35+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
36+
return %3 : tensor<1x128x768xui16>
37+
38+
}
39+
40+
// CHECK-LABEL: func.func @test_qdq_pattern3(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16>
41+
// CHECK: onnx.DequantizeLinear
42+
// CHECK: onnx.QuantizeLinear
43+
44+
func.func @test_qdq_pattern4(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
45+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
46+
%1 = onnx.Constant dense<39664> : tensor<ui16>
47+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 1 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
48+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
49+
return %3 : tensor<1x128x768xui16>
50+
51+
}
52+
53+
// CHECK-LABEL: func.func @test_qdq_pattern4(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16>
54+
// CHECK: onnx.DequantizeLinear
55+
// CHECK: onnx.QuantizeLinear
56+
57+
func.func @test_qdq_pattern6(%arg0: tensor<1x128x768xui16>, %arg1: tensor<f32>) -> tensor<1x128x768xui16> {
58+
%0 = onnx.Constant dense<39664> : tensor<ui16>
59+
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
60+
%2 = "onnx.QuantizeLinear"(%1, %arg1, %0) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
61+
return %2 : tensor<1x128x768xui16>
62+
}
63+
64+
// CHECK-LABEL: func.func @test_qdq_pattern6(%arg0: tensor<1x128x768xui16>, %arg1: tensor<f32>) -> tensor<1x128x768xui16>
65+
// CHECK: onnx.DequantizeLinear
66+
// CHECK: onnx.QuantizeLinear
67+
68+
func.func @test_qdq_pattern7(%arg0: tensor<1x128x768xui16>, %arg1: tensor<ui16>) -> tensor<1x128x768xui16> {
69+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
70+
%1 = "onnx.DequantizeLinear"(%arg0, %0, %arg1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
71+
%2 = "onnx.QuantizeLinear"(%1, %0, %arg1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
72+
return %2 : tensor<1x128x768xui16>
73+
}
74+
75+
// CHECK-LABEL: func.func @test_qdq_pattern7(%arg0: tensor<1x128x768xui16>, %arg1: tensor<ui16>) -> tensor<1x128x768xui16>
76+
// CHECK: onnx.DequantizeLinear
77+
// CHECK: onnx.QuantizeLinear
78+
79+
func.func @test_qdq_pattern8(%arg0: tensor<1x128x768xi16>) -> tensor<1x128x768xui16> {
80+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
81+
%1 = onnx.Constant dense<39664> : tensor<ui16>
82+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xi16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
83+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
84+
return %3 : tensor<1x128x768xui16>
85+
}
86+
87+
// CHECK-LABEL: func.func @test_qdq_pattern8(%arg0: tensor<1x128x768xi16>) -> tensor<1x128x768xui16>
88+
// CHECK: onnx.DequantizeLinear
89+
// CHECK: onnx.QuantizeLinear

0 commit comments

Comments
 (0)