Skip to content

Commit dc4bc72

Browse files
authored
ConvolutionOp verifier: allow rhs per_axis quantized and result per_tensor quantized (#2094)
The previous check was too restrictive and caused test failures during integration.
1 parent 888e4ae commit dc4bc72

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

stablehlo/dialect/TypeInference.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3502,16 +3502,15 @@ LogicalResult verifyConvolutionOp(
35023502
if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
35033503
return success();
35043504
// convolution_c31
3505-
if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
3506-
return emitOptionalError(location,
3507-
"rhs and result are of mixed per_tensor and "
3508-
"per_axis quantized tensor type ",
3509-
rhsType, " and ", resultType);
3510-
}
3511-
35123505
auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
35133506
auto resultQPAType =
35143507
resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
3508+
if (!rhsQPAType && resultQPAType) {
3509+
return emitOptionalError(
3510+
location, "per-tensor rhs expects per-tensor result but received ",
3511+
rhsType, " and ", resultType, " respectively");
3512+
}
3513+
35153514
// convolution_c32
35163515
if (rhsQPAType &&
35173516
rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)

stablehlo/tests/ops_stablehlo_quantized.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -860,14 +860,14 @@ func.func @convolution_c30(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15
860860

861861
// -----
862862

863-
func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
864-
// expected-error@+1 {{rhs and result are of mixed per_tensor and per_axis quantized tensor type 'tensor<3x3x207x16x!quant.uniform<i8:f32:0, {1.000000e-01:-30}>>' and 'tensor<1x8x8x16x!quant.uniform<i8:f32, 1.000000e+01:50>>'}}
863+
func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {10.0:50}>> {
864+
// expected-error@+1 {{per-tensor rhs expects per-tensor result but received 'tensor<3x3x207x16x!quant.uniform<i8:f32, 1.000000e-01:-30>>' and 'tensor<1x8x8x16x!quant.uniform<i8:f32:0, {1.000000e+01:50}>>' respectively}}
865865
%0 = stablehlo.convolution(%arg0, %arg1)
866866
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
867867
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
868868
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
869-
(tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
870-
func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
869+
(tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {10.0:50}>>
870+
func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {10.0:50}>>
871871
}
872872

873873
// -----

0 commit comments

Comments
 (0)