Skip to content

Commit 177e382

Browse files
authored
[mlir][tosa] Get quantized element type with sign info. (#169387)
As mentioned in https://github.com/llvm/llvm-project/blob/a27bb38ee6f5762e715803d8eb6ffc5a8dd09575/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h#L109 `QuantType::getStorageType` doesn't capture the sign information. This lead to the following IR to fail during verification: ``` func.func @clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> } ``` with `'tosa.clamp' op min/max attributes types are incompatible with input/output element types` error since `getStorageType` was returning signed integer but the clamp attributes were unsigned. This PR updates the usage of `getStorageType` in tosa codebase to correctly use the signed info for the quantized type.
1 parent ccbd0d1 commit 177e382

File tree

8 files changed

+83
-14
lines changed

8 files changed

+83
-14
lines changed

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
8888
IntegerAttr quantBits, int filterQuantDim,
8989
bool isSigned, BoolAttr narrowRange);
9090

91+
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType);
92+
9193
} // namespace tosa
9294
} // namespace mlir
9395

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1616
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1717
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18+
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
1819
#include "mlir/IR/BuiltinTypeInterfaces.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
539540
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
540541
if (auto quantType =
541542
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
542-
inputEType = quantType.getStorageType();
543+
inputEType = getStorageElementTypeFromQuantized(quantType);
543544
}
544545

545546
Attribute newMinValAttr, newMaxValAttr;

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
563563
static Type getStorageElementTypeOrSelf(Type type) {
564564
auto srcType = getElementTypeOrSelf(type);
565565
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
566-
srcType = quantType.getStorageType();
566+
srcType = getStorageElementTypeFromQuantized(quantType);
567567
return srcType;
568568
}
569569

@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
631631
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
632632

633633
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
634-
inputEType = quantType.getStorageType();
634+
inputEType = getStorageElementTypeFromQuantized(quantType);
635635

636636
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
637-
weightEType = quantType.getStorageType();
637+
weightEType = getStorageElementTypeFromQuantized(quantType);
638638

639639
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
640-
biasEType = quantType.getStorageType();
640+
biasEType = getStorageElementTypeFromQuantized(quantType);
641641

642642
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
643-
resultEType = quantType.getStorageType();
643+
resultEType = getStorageElementTypeFromQuantized(quantType);
644644

645645
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
646646
// for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
709709

710710
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
711711
outputType.getElementType())) {
712-
if (result.getStorageType() == attrType.getElementType())
712+
if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
713713
return success();
714714
}
715715

@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
727727
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
728728

729729
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
730-
inputEType = quantType.getStorageType();
730+
inputEType = getStorageElementTypeFromQuantized(quantType);
731731

732732
auto accType = op.getAccType();
733733
if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
752752
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
753753

754754
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
755-
resultEType = quantType.getStorageType();
755+
resultEType = getStorageElementTypeFromQuantized(quantType);
756756

757757
return success();
758758
}
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
11791179
llvm::cast<ShapedType>(getInput().getType()).getElementType();
11801180
if (auto quantType =
11811181
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1182-
inputETy = quantType.getStorageType();
1182+
inputETy = getStorageElementTypeFromQuantized(quantType);
11831183
}
11841184
mlir::Type outputETy =
11851185
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
11861186
if (auto quantType =
11871187
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1188-
outputETy = quantType.getStorageType();
1188+
outputETy = getStorageElementTypeFromQuantized(quantType);
11891189
}
11901190
if (inputETy != outputETy)
11911191
return emitOpError("input/output element types are incompatible.");

mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
395395
maxAttr, quantBits, filterQuantDim,
396396
isSigned, narrowRange));
397397
}
398+
399+
Type mlir::tosa::getStorageElementTypeFromQuantized(
400+
quant::QuantizedType quantType) {
401+
auto quantEty = quantType.getStorageType();
402+
// StorageType doesn't capture the sign information
403+
// Explicitly create unsigned type if needed
404+
if (!quantType.isSigned()) {
405+
quantEty = IntegerType::get(quantEty.getContext(),
406+
quantEty.getIntOrFloatBitWidth(),
407+
IntegerType::Unsigned);
408+
}
409+
return quantEty;
410+
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
360360
return %1 : tensor<4xi8>
361361
}
362362

363+
// -----
364+
365+
// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
366+
// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
367+
func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
368+
%0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
369+
%1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
370+
return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
371+
}
372+
373+
// -----
374+
375+
// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
376+
// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
377+
func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
378+
%0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
379+
%1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
380+
return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
381+
}
382+
383+
// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
384+
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
385+
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
386+
func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
387+
%0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
388+
%1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
389+
return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
390+
}
391+
392+
363393
// -----
364394

365395
// CHECK-LABEL: @concat_fold

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
279279
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
280280
}
281281

282+
// -----
283+
// CHECK-LABEL: clamp_quantized_unsigned
284+
func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
285+
%0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
286+
return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
287+
}
288+
282289
// -----
283290
// CHECK-LABEL: sigmoid
284291
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

mlir/test/Dialect/Tosa/quant-test.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
22

33
// -----
4-
// CHECK-LABEL: test_build_qtype
5-
func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
4+
// CHECK-LABEL: test_build_qtype_unsigned
5+
func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
66
// CHECK: tosa.negate
7-
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
7+
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
88
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
99
}
1010

11+
// -----
12+
// CHECK-LABEL: test_build_qtype_signed
13+
func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
14+
// CHECK: tosa.negate
15+
%0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
16+
return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
17+
}
18+
1119
// -----
1220
// CHECK-LABEL: test_build_mult_and_shift
1321
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
12221222
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
12231223
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
12241224
}
1225+
1226+
// -----
1227+
1228+
func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
1229+
// expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
1230+
%0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
1231+
return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
1232+
}

0 commit comments

Comments
 (0)