Skip to content

Commit ab6bb16

Browse files
author
sushmita
committed
attr chk simplified, renaming, add flag
1 parent 9271d27 commit ab6bb16

File tree

7 files changed

+145
-176
lines changed

7 files changed

+145
-176
lines changed

src/Compiler/CompilerPasses.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
142142
opts.enableQuarkQuantizedLegalization));
143143

144144
// Passes for removing redundant concat, slice and cast QDQ Ops
145-
pm.addPass(createQDQOptONNXToONNXPass());
145+
if (opts.enableRemoveDqQOp)
146+
pm.addPass(createQDQOptONNXToONNXPass());
146147

147148
// One more call to ONNX shape inference/canonicalization/... to update
148149
// shape if possible.

src/Compiler/CompilerPasses.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct OnnxToMlirOptions {
3030
bool enableConvTransposeDecompose = false;
3131
bool enableConvTransposeDecomposeToPhasedConv = false;
3232
bool enableConvTranspose1dDecomposeToPhasedConv = false;
33+
bool enableRemoveDqQOp = true;
3334
};
3435

3536
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,

src/Dialect/ONNX/Transforms/QDQOpt.cpp

Lines changed: 7 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -35,70 +35,6 @@ static ElementsAttr getElementAttributeFromConstant(Value val) {
3535
return nullptr;
3636
}
3737

38-
static mlir::LogicalResult equalsDefaultIntegerAttr(
39-
mlir::IntegerAttr ia, int64_t defaultValue) {
40-
auto it = mlir::cast<mlir::IntegerType>(ia.getType());
41-
int64_t got = it.isUnsignedInteger()
42-
? static_cast<int64_t>(ia.getValue().getZExtValue())
43-
: ia.getValue().getSExtValue();
44-
return (got == defaultValue) ? mlir::success() : mlir::failure();
45-
}
46-
47-
static mlir::LogicalResult equalsDefaultIntElements(
48-
mlir::ElementsAttr ea, int64_t defaultValue) {
49-
auto st = mlir::dyn_cast<mlir::ShapedType>(ea.getType());
50-
if (!st)
51-
return mlir::failure();
52-
mlir::Type et = st.getElementType();
53-
if (!et.isIntOrIndex())
54-
return mlir::failure();
55-
const bool isUnsigned = et.isa<mlir::IntegerType>() &&
56-
et.cast<mlir::IntegerType>().isUnsignedInteger();
57-
if (ea.isSplat()) {
58-
llvm::APInt api = ea.getSplatValue<llvm::APInt>();
59-
int64_t got = isUnsigned ? static_cast<int64_t>(api.getZExtValue())
60-
: api.getSExtValue();
61-
return (got == defaultValue) ? mlir::success() : mlir::failure();
62-
}
63-
for (const llvm::APInt &api : ea.getValues<llvm::APInt>()) {
64-
int64_t got = isUnsigned ? static_cast<int64_t>(api.getZExtValue())
65-
: api.getSExtValue();
66-
if (got != defaultValue)
67-
return mlir::failure();
68-
}
69-
return mlir::success();
70-
}
71-
72-
static mlir::LogicalResult checkAttrAgainstDefault(
73-
mlir::Attribute attr, int64_t defaultValue) {
74-
if (!attr)
75-
return mlir::failure();
76-
if (auto ia = mlir::dyn_cast<mlir::IntegerAttr>(attr))
77-
return equalsDefaultIntegerAttr(ia, defaultValue);
78-
if (auto ea = mlir::dyn_cast<mlir::ElementsAttr>(attr))
79-
return equalsDefaultIntElements(ea, defaultValue);
80-
return mlir::failure();
81-
}
82-
83-
static mlir::LogicalResult checkIntegerAttributeEquals(mlir::Operation *op1,
84-
mlir::Operation *op2, mlir::StringRef attrName, int64_t defaultValue) {
85-
mlir::Attribute attr1 = op1->getAttr(attrName);
86-
mlir::Attribute attr2 = op2->getAttr(attrName);
87-
// Case 0: both missing => both implicitly default
88-
if (!attr1 && !attr2)
89-
return mlir::success();
90-
// Case 1: both present and identical
91-
if (attr1 && attr2 && attr1 == attr2)
92-
return mlir::success();
93-
// Case 2: one side missing => present side must equal default
94-
if (!attr1)
95-
return checkAttrAgainstDefault(attr2, defaultValue);
96-
if (!attr2)
97-
return checkAttrAgainstDefault(attr1, defaultValue);
98-
// Case 3: both present but not identical
99-
return mlir::failure();
100-
}
101-
10238
//===----------------------------------------------------------------------===//
10339
// Pattern to remove QDQ pairs
10440
//===----------------------------------------------------------------------===//
@@ -107,23 +43,16 @@ struct FoldQDQPattern : public OpRewritePattern<ONNXQuantizeLinearOp> {
10743
using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
10844
LogicalResult matchAndRewrite(
10945
ONNXQuantizeLinearOp qOp, PatternRewriter &rewriter) const override {
46+
11047
auto dqOp = qOp.getX().getDefiningOp<ONNXDequantizeLinearOp>();
11148
if (!dqOp)
11249
return failure();
11350

114-
// 1. Check attributes with defaults (axis=1, block_size=0,
115-
// saturate=1)
116-
Operation *dqOperation = dqOp.getOperation();
117-
Operation *qOperation = qOp.getOperation();
118-
119-
if (failed(
120-
checkIntegerAttributeEquals(dqOperation, qOperation, "axis", 1)) ||
121-
failed(checkIntegerAttributeEquals(
122-
dqOperation, qOperation, "block_size", 0)) ||
123-
failed(checkIntegerAttributeEquals(
124-
dqOperation, qOperation, "saturate", 1))) {
51+
// 1. Check Attributes
52+
if (qOp.getAxis() != dqOp.getAxis())
53+
return failure();
54+
if (qOp.getBlockSize() != dqOp.getBlockSize())
12555
return failure();
126-
}
12756

12857
// 2. Check zero-points
12958
auto zpAttr1 = getElementAttributeFromConstant(dqOp.getXZeroPoint());
@@ -172,9 +101,9 @@ struct QDQOptONNXToONNXPass
172101
: public PassWrapper<QDQOptONNXToONNXPass, OperationPass<func::FuncOp>> {
173102

174103
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QDQOptONNXToONNXPass)
175-
StringRef getArgument() const override { return "qdq-opt-onnx-to-onnx"; }
104+
StringRef getArgument() const override { return "dqq-opt-onnx-to-onnx"; }
176105
StringRef getDescription() const override {
177-
return "Remove QDQ ops and surrounding QDQ if safe.";
106+
return "Remove DqQ ops and surrounding DqQ if safe.";
178107
}
179108

180109
void runOnOperation() override {
Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1-
// RUN: onnx-mlir-opt --canonicalize --qdq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --canonicalize --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
22

3-
func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4-
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5-
%1 = onnx.Constant dense<39664> : tensor<ui16>
6-
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7-
%3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = f32} : (tensor<*xf32>) -> tensor<*xf32>
8-
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9-
return %4 : tensor<*xui16>
10-
}
3+
func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5+
%1 = onnx.Constant dense<39664> : tensor<ui16>
6+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7+
%3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = f32} : (tensor<*xf32>) -> tensor<*xf32>
8+
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9+
return %4 : tensor<*xui16>
10+
}
1111

12-
// CHECK-LABEL: func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13-
// CHECK-NOT: onnx.DequantizeLinear
14-
// CHECK-NOT: onnx.Cast
15-
// CHECK-NOT: onnx.QuantizeLinear
12+
// CHECK-LABEL: func.func @test_cast_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13+
// CHECK-NOT: onnx.DequantizeLinear
14+
// CHECK-NOT: onnx.Cast
15+
// CHECK-NOT: onnx.QuantizeLinear
1616

1717
func.func @test_cast_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16> {
18-
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
19-
%1 = onnx.Constant dense<39664> : tensor<ui16>
20-
%2 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<*xui16>) -> tensor<*xf32>
21-
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
22-
return %3 : tensor<*xui16>
18+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
19+
%1 = onnx.Constant dense<39664> : tensor<ui16>
20+
%2 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<*xui16>) -> tensor<*xf32>
21+
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
22+
return %3 : tensor<*xui16>
2323
}
2424

2525
// CHECK-LABEL: func.func @test_cast_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16>
2626
// CHECK: onnx.Cast
27-
// CHECK: onnx.QuantizeLinear
27+
// CHECK: onnx.QuantizeLinear
Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
1-
// RUN: onnx-mlir-opt --canonicalize --qdq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --canonicalize --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
22

3-
func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4-
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5-
%1 = onnx.Constant dense<39664> : tensor<ui16>
6-
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7-
%3 = "onnx.Concat"(%2) {axis = 1 : si64} : (tensor<*xf32>) -> tensor<*xf32>
8-
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9-
return %4 : tensor<*xui16>
10-
}
3+
func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16> {
4+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
5+
%1 = onnx.Constant dense<39664> : tensor<ui16>
6+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
7+
%3 = "onnx.Concat"(%2) {axis = 1 : si64} : (tensor<*xf32>) -> tensor<*xf32>
8+
%4 = "onnx.QuantizeLinear"(%3, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
9+
return %4 : tensor<*xui16>
10+
}
1111

12-
// CHECK-LABEL: func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13-
// CHECK-NOT: onnx.DequantizeLinear
14-
// CHECK-NOT: onnx.Concat
15-
// CHECK-NOT: onnx.QuantizeLinear
16-
// CHECK: return %arg0 : tensor<*xui16>
12+
// CHECK-LABEL: func.func @test_concat_pattern1(%arg0: tensor<*xui16>) -> tensor<*xui16>
13+
// CHECK-NOT: onnx.DequantizeLinear
14+
// CHECK-NOT: onnx.Concat
15+
// CHECK-NOT: onnx.QuantizeLinear
16+
// CHECK: return %arg0 : tensor<*xui16>
1717

1818
func.func @test_concat_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16> {
19-
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
20-
%1 = onnx.Constant dense<39664> : tensor<ui16>
21-
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
22-
%3 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
23-
%4 = "onnx.Concat"(%2, %3) {axis = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
24-
%5 = "onnx.QuantizeLinear"(%4, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
25-
return %5 : tensor<*xui16>
26-
}
19+
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
20+
%1 = onnx.Constant dense<39664> : tensor<ui16>
21+
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
22+
%3 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
23+
%4 = "onnx.Concat"(%2, %3) {axis = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
24+
%5 = "onnx.QuantizeLinear"(%4, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
25+
return %5 : tensor<*xui16>
26+
}
2727

28-
// CHECK-LABEL: func.func @test_concat_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16>
29-
// CHECK: onnx.DequantizeLinear
30-
// CHECK: onnx.Concat
31-
// CHECK: onnx.QuantizeLinear
28+
// CHECK-LABEL: func.func @test_concat_pattern2(%arg0: tensor<*xui16>) -> tensor<*xui16>
29+
// CHECK: onnx.DequantizeLinear
30+
// CHECK: onnx.Concat
31+
// CHECK: onnx.QuantizeLinear

test/mlir/onnx/onnx_remove_qdq.mlir renamed to test/mlir/onnx/onnx_remove_dqq.mlir

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir-opt --qdq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --dqq-opt-onnx-to-onnx %s -split-input-file | FileCheck %s
22

33
func.func @test_qdq_pattern1(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
44
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
@@ -54,19 +54,6 @@ return %3 : tensor<1x128x768xui16>
5454
// CHECK: onnx.DequantizeLinear
5555
// CHECK: onnx.QuantizeLinear
5656

57-
func.func @test_qdq_pattern5(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16> {
58-
%0 = onnx.Constant dense<2.57987776E-5> : tensor<f32>
59-
%1 = onnx.Constant dense<39664> : tensor<ui16>
60-
%2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>
61-
%3 = "onnx.QuantizeLinear"(%2, %0, %1) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 2 : si64} : (tensor<1x128x768xf32>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xui16>
62-
return %3 : tensor<1x128x768xui16>
63-
64-
}
65-
66-
// CHECK-LABEL: func.func @test_qdq_pattern5(%arg0: tensor<1x128x768xui16>) -> tensor<1x128x768xui16>
67-
// CHECK: onnx.DequantizeLinear
68-
// CHECK: onnx.QuantizeLinear
69-
7057
func.func @test_qdq_pattern6(%arg0: tensor<1x128x768xui16>, %arg1: tensor<f32>) -> tensor<1x128x768xui16> {
7158
%0 = onnx.Constant dense<39664> : tensor<ui16>
7259
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %0) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128x768xui16>, tensor<f32>, tensor<ui16>) -> tensor<1x128x768xf32>

0 commit comments

Comments (0)