Commit a8f44d0

Updates to ConvolutionOp verifier to support quantization constraints (#2079)
1 parent d978408 commit a8f44d0

2 files changed: 146 additions & 0 deletions


stablehlo/dialect/TypeInference.cpp

Lines changed: 74 additions & 0 deletions
@@ -65,6 +65,25 @@ limitations under the License.
 
 namespace mlir {
 namespace hlo {
+namespace {
+//===----------------------------------------------------------------------===//
+// Utils for quantization specific verifications
+//===----------------------------------------------------------------------===//
+template <typename T>
+bool allQuantized(ArrayRef<Type> typeRange) {
+  return llvm::all_of(typeRange, [&](Type val) {
+    return val.cast<ShapedType>().getElementType().isa<T>();
+  });
+}
+
+template <typename T>
+bool noneQuantized(ArrayRef<Type> typeRange) {
+  return llvm::all_of(typeRange, [&](Type val) {
+    return !val.cast<ShapedType>().getElementType().isa<T>();
+  });
+}
+
+}  // namespace
 
 //===----------------------------------------------------------------------===//
 // Utils for shape functions.
@@ -3453,6 +3472,61 @@ LogicalResult verifyConvolutionOp(
         "is incompatible with return type of operation ",
         shapedResultType, "");
 
+  llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};
+  if (noneQuantized<quant::QuantizedType>(typeEntries)) return success();
+  // convolution_c28
+  if (!allQuantized<quant::QuantizedType>(typeEntries)) {
+    return emitOptionalError(location,
+                             "not all of operands and result are quantized");
+  }
+
+  auto lhsQType =
+      getElementTypeOrSelf(lhsType).dyn_cast<quant::QuantizedType>();
+  auto rhsQType =
+      getElementTypeOrSelf(rhsType).dyn_cast<quant::QuantizedType>();
+  auto resultQType =
+      getElementTypeOrSelf(resultType).dyn_cast<quant::QuantizedType>();
+  // convolution_c29
+  if (lhsQType.getStorageType() != rhsQType.getStorageType())
+    return emitOptionalError(location, "mismatched operand storage types ",
+                             lhsQType.getStorageType(), " and ",
+                             rhsQType.getStorageType());
+  // convolution_c30
+  auto expressedType = lhsQType.getExpressedType();
+  if (expressedType != rhsQType.getExpressedType() ||
+      expressedType != resultQType.getExpressedType())
+    return emitOptionalError(location,
+                             "mismatched operands and result expressed types");
+
+  llvm::SmallVector<Type, 2> typeEntriesPerAxis{rhsType, resultType};
+  if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
+    return success();
+  // convolution_c31
+  if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
+    return emitOptionalError(location,
+                             "rhs and result are of mixed per_tensor and "
+                             "per_axis quantized tensor type ",
+                             rhsType, " and ", resultType);
+  }
+
+  auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
+  auto resultQPAType =
+      resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
+  // convolution_c32
+  if (rhsQPAType &&
+      rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)
+    return emitOptionalError(
+        location, "mismatched kernel_output_feature_dimension ",
+        kernelOutputFeatureDimension, " and rhs quantized dimension ",
+        rhsQPAType.getQuantizedDimension());
+  // convolution_c33
+  if (resultQPAType &&
+      resultQPAType.getQuantizedDimension() != outputFeatureDimension)
+    return emitOptionalError(location, "mismatched output_feature_dimension ",
+                             outputFeatureDimension,
+                             " and result quantized dimension ",
+                             resultQPAType.getQuantizedDimension());
+
   return success();
 }
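Taken together, the new checks mirror constraints convolution_c28 through convolution_c33 of the StableHLO spec. For contrast with the negative tests below, a convolution whose operands and result are all per-tensor quantized over the same i8 storage type and f32 expressed type passes every check. A minimal sketch in the style of the tests (the function name and the scale/zero-point values are illustrative, not part of the commit):

func.func @convolution_quantized_per_tensor(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // All three types are quantized (c28), share the i8 storage type (c29)
  // and the f32 expressed type (c30); none is per-axis, so c31-c33 are
  // trivially satisfied and the verifier returns success.
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}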

stablehlo/tests/ops_stablehlo_quantized.mlir

Lines changed: 72 additions & 0 deletions
@@ -821,3 +821,75 @@ func.func @illegal_storage_type_for_quantized_element_type(%arg0: tensor<4x!quant.uniform<si8:f32, 1.000000e+00>>) -> tensor<4xf32> {
   %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<4x!quant.uniform<si8:f32, 1.000000e+00>>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 }
+
+// -----
+
+func.func @convolution_c28(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
+  // expected-error@+1 {{not all of operands and result are quantized}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+}
+
+// -----
+
+func.func @convolution_c29(%arg0: tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
+  // expected-error@+1 {{mismatched operand storage types 'i16' and 'i8'}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+}
+
+// -----
+
+func.func @convolution_c30(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
+  // expected-error@+1 {{mismatched operands and result expressed types}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+}
+
+// -----
+
+func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
+  // expected-error@+1 {{rhs and result are of mixed per_tensor and per_axis quantized tensor type 'tensor<3x3x207x16x!quant.uniform<i8:f32:0, {1.000000e-01:-30}>>' and 'tensor<1x8x8x16x!quant.uniform<i8:f32, 1.000000e+01:50>>'}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
+}
+
+// -----
+
+func.func @convolution_c32(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>> {
+  // expected-error@+1 {{mismatched kernel_output_feature_dimension 3 and rhs quantized dimension 0}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
+}
+
+// -----
+
+func.func @convolution_c33(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>> {
+  // expected-error@+1 {{mismatched output_feature_dimension 3 and result quantized dimension 0}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
+    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
+}
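The tests above are all negative cases. A per-axis variant that satisfies convolution_c31 through convolution_c33 would quantize both rhs and result per-axis, with the rhs quantized on kernel_output_feature_dimension 3 and the result on output_feature_dimension 3. A hypothetical sketch, not part of the commit, with illustrative scales sized to the two output features:

func.func @convolution_per_axis_ok(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x2x!quant.uniform<i8:f32:3, {0.1:-30, 0.5:-30}>>) -> tensor<1x8x8x2x!quant.uniform<i8:f32:3, {0.1:-30, 0.5:-30}>> {
  // rhs and result are both per-axis (c31); the rhs quantized dimension is 3,
  // matching kernel dimension o in [0, 1, i, o] (c32); the result quantized
  // dimension is 3, matching output dimension f in [b, 0, 1, f] (c33).
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x2x!quant.uniform<i8:f32:3, {0.1:-30, 0.5:-30}>>) -> tensor<1x8x8x2x!quant.uniform<i8:f32:3, {0.1:-30, 0.5:-30}>>
  func.return %0 : tensor<1x8x8x2x!quant.uniform<i8:f32:3, {0.1:-30, 0.5:-30}>>
}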
