-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Get quantized element type with sign info. #169387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
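@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Sayan Saha (sahas3) Changes: As mentioned in llvm-project/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h (line 109 in a27bb38),
QuantType::getStorageType doesn't capture the sign information. This led to tosa.clamp IR on unsigned quantized tensors with unsigned min/max attributes failing verification
with the error "'tosa.clamp' op min/max attributes types are incompatible with input/output element types", since getStorageType was returning a signed integer but the clamp attributes were unsigned. This PR updates the usage of getStorageType in the tosa codebase to correctly use the sign information of the quantized type. Full diff: https://github.com/llvm/llvm-project/pull/169387.diff 8 Files Affected: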
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 9d9a934cdfd5e..0e751911df94d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -19,7 +19,7 @@
#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
-namespace mlir {
+namespace mlir {
namespace tosa {
//===----------------------------------------------------------------------===//
@@ -88,6 +88,9 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
IntegerAttr quantBits, int filterQuantDim,
bool isSigned, BoolAttr narrowRange);
+Type
+getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a118ac9c4b111..c420a4c9596ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
}
Attribute newMinValAttr, newMaxValAttr;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59d39168..1c175f9ab0207 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
- srcType = quantType.getStorageType();
+ srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
- weightEType = quantType.getStorageType();
+ weightEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
- biasEType = quantType.getStorageType();
+ biasEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
- if (result.getStorageType() == attrType.getElementType())
+ if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
return success();
}
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
- inputETy = quantType.getStorageType();
+ inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
- outputETy = quantType.getStorageType();
+ outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a090e6d4..c55b13dc98cc5 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
maxAttr, quantBits, filterQuantDim,
isSigned, narrowRange));
}
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+ quant::QuantizedType quantType) {
+ auto quantEty = quantType.getStorageType();
+ // StorageType doesn't capture the sign information
+ // Explicitly create unsigned type if needed
+ if (!quantType.isSigned()) {
+ quantEty = IntegerType::get(quantEty.getContext(),
+ quantEty.getIntOrFloatBitWidth(),
+ IntegerType::Unsigned);
+ }
+ return quantEty;
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fc5ea7710e2c4..84776c47b628d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
return %1 : tensor<4xi8>
}
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
+func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
+func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
+// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
+// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
+func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+
// -----
// CHECK-LABEL: @concat_fold
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..652447bd6056e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
}
+// -----
+// CHECK-LABEL: clamp_quantized_unsigned
+func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index f0ad4eb4fdb0b..88dffe7fdd2e8 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -1,13 +1,21 @@
// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
// -----
-// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+// CHECK-LABEL: test_build_qtype_unsigned
+func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
- %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
}
+// -----
+// CHECK-LABEL: test_build_qtype_signed
+func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
+ // CHECK: tosa.negate
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+ return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+}
+
// -----
// CHECK-LABEL: test_build_mult_and_shift
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..ea64d468f151e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
+ %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
Found this issue when running a pipeline from the tensorflow repo, as it produces invalid IR. Planning to fix that issue in the tensorflow repo once this PR lands and is available there. |
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @sahas3!
As mentioned in
llvm-project/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
Line 109 in a27bb38
QuantType::getStorageType doesn't capture the sign information. This led to tosa.clamp IR on unsigned quantized tensors failing verification with the
"'tosa.clamp' op min/max attributes types are incompatible with input/output element types" error, since
getStorageType was returning a signed integer but the clamp attributes were unsigned. This PR updates the usage of
getStorageType in the tosa codebase to correctly use the sign information of the quantized type.