Skip to content

Commit 1e81b4a

Browse files
committed
feat: Generate no error when scalar values with different ranks are provided for scale and zero point
1 parent da66cda commit 1e81b4a

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ LogicalResult ONNXDequantizeLinearOpShapeHelper::computeShape() {
9292

9393
LogicalResult ONNXDequantizeLinearOp::verify() {
9494
// Is tensor known to be a scalar (rank 0 or rank 1 with 1 element)?
95-
auto isScalar = [](RankedTensorType t) -> bool {
95+
auto isScalar = [](ShapedType t) -> bool {
9696
return t.getRank() == 0 || (t.getRank() == 1 && t.getDimSize(0) == 1);
9797
};
9898

@@ -108,7 +108,8 @@ LogicalResult ONNXDequantizeLinearOp::verify() {
108108
Value zero = getXZeroPoint();
109109
if (!isNoneValue(zero)) {
110110
const auto zeroTy = mlir::cast<ShapedType>(zero.getType());
111-
if (zeroTy.hasRank() && scaleTy.hasRank() &&
111+
if (zeroTy.hasRank() && scaleTy.hasRank() && !isScalar(zeroTy) &&
112+
!isScalar(scaleTy) &&
112113
(zeroTy.getRank() != scaleTy.getRank() ||
113114
zeroTy.getShape() != scaleTy.getShape())) {
114115
return emitOpError("x_zero_point must have the same shape as x_scale");

test/mlir/onnx/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ func.func @test_dequantize_linear_verifier_2(%arg0 : tensor<5x5x1xi32>, %arg1 :
115115

116116
// -----
117117

118+
// We should not get any error for scalar values with different ranks.
119+
func.func @test_scalar_with_different_rank(%arg0 : tensor<5x5x1xi32>) -> tensor<5x5x1xf32> {
120+
%0 = "onnx.Constant"(){ value = dense<1> : tensor<i32>} : () -> tensor<i32>
121+
%1 = "onnx.Constant"(){ value = dense<2.0> : tensor<1xf32> } : () -> tensor<1xf32>
122+
%2 = "onnx.DequantizeLinear"(%arg0, %1, %0) {} : (tensor<5x5x1xi32>, tensor<1xf32>, tensor<i32>) -> tensor<5x5x1xf32>
123+
"onnx.Return"(%2) : (tensor<5x5x1xf32>) -> ()
124+
}
125+
126+
// -----
127+
118128
func.func @test_dequantize_linear_verifier_3(%arg0 : tensor<5x5x1xi32>, %arg1 : tensor<3xf32>, %arg2 : tensor<3xi32>) -> tensor<*xf32> {
119129
// expected-error @+1 {{'onnx.DequantizeLinear' op x_scale and x_zero_point 1-D tensor length must match the input axis dim size}}
120130
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x5x1xi32>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32>

0 commit comments

Comments
 (0)