sahas3 (Member) commented Nov 24, 2025

As mentioned in

/// be signed or unsigned. Use the isSigned() accessor to differentiate.
QuantizedType::getStorageType doesn't capture the sign information. This caused the following IR to fail verification:

func.func @clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
    %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}

The verifier reported the error "'tosa.clamp' op min/max attributes types are incompatible with input/output element types"
because getStorageType returned the storage type without the sign information (i8), while the clamp attributes were unsigned (ui8).

This PR updates the uses of getStorageType in the TOSA codebase so that they take the sign information of the quantized type into account.
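
To illustrate the mismatch and the fix without pulling in MLIR, here is a minimal stand-alone sketch. The types `IntType` and `QuantType` and the function `storageElementTypeWithSign` are hypothetical stand-ins invented for this example; they only model the idea of the helper added in this PR and are not the MLIR API.

```cpp
#include <cassert>

// Simplified stand-ins for illustration only; these are NOT the MLIR classes.
enum class Signedness { Signless, Signed, Unsigned };

struct IntType {
  unsigned width;
  Signedness sign;
  bool operator==(const IntType &other) const {
    return width == other.width && sign == other.sign;
  }
};

struct QuantType {
  IntType storage; // what a getStorageType()-like accessor would return
  bool isSigned;   // the sign is tracked separately from the storage type
};

// Models the idea of the new helper: keep the storage type for signed
// quantized types, and rebuild it as an unsigned type otherwise.
IntType storageElementTypeWithSign(const QuantType &q) {
  if (!q.isSigned)
    return IntType{q.storage.width, Signedness::Unsigned};
  return q.storage;
}

// Usage example: an unsigned 8-bit quantized type compared against a ui8
// attribute type.
void checkExamples() {
  const QuantType u8Quant{{8, Signedness::Signless}, /*isSigned=*/false};
  const IntType ui8Attr{8, Signedness::Unsigned};

  // Comparing the raw storage type: the sign is missing, so this mismatches.
  assert(!(u8Quant.storage == ui8Attr));

  // Comparing the sign-aware element type: this matches.
  assert(storageElementTypeWithSign(u8Quant) == ui8Attr);
}
```

In this model, a signed quantized type keeps its storage type unchanged, which corresponds to the signed tests in the diff continuing to use `i8` attributes.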

llvmbot (Member) commented Nov 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Sayan Saha (sahas3)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/169387.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+4-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+10-10)
  • (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+13)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+30)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+7)
  • (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+11-3)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 9d9a934cdfd5e..0e751911df94d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -19,7 +19,7 @@
 #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
 #include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 
-namespace mlir {
+namespace mlir {    
 namespace tosa {
 
 //===----------------------------------------------------------------------===//
@@ -88,6 +88,9 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
                                   IntegerAttr quantBits, int filterQuantDim,
                                   bool isSigned, BoolAttr narrowRange);
 
+Type
+getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a118ac9c4b111..c420a4c9596ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
     if (auto quantType =
             llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
-      inputEType = quantType.getStorageType();
+      inputEType = getStorageElementTypeFromQuantized(quantType);
     }
 
     Attribute newMinValAttr, newMaxValAttr;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59d39168..1c175f9ab0207 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
 static Type getStorageElementTypeOrSelf(Type type) {
   auto srcType = getElementTypeOrSelf(type);
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
-    srcType = quantType.getStorageType();
+    srcType = getStorageElementTypeFromQuantized(quantType);
   return srcType;
 }
 
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
   bool resultIsFloat = llvm::isa<FloatType>(resultEType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
-    inputEType = quantType.getStorageType();
+    inputEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
-    weightEType = quantType.getStorageType();
+    weightEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
-    biasEType = quantType.getStorageType();
+    biasEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
-    resultEType = quantType.getStorageType();
+    resultEType = getStorageElementTypeFromQuantized(quantType);
 
   if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
     // for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
 
   if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
           outputType.getElementType())) {
-    if (result.getStorageType() == attrType.getElementType())
+    if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
       return success();
   }
 
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
-    inputEType = quantType.getStorageType();
+    inputEType = getStorageElementTypeFromQuantized(quantType);
 
   auto accType = op.getAccType();
   if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
       llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
-    resultEType = quantType.getStorageType();
+    resultEType = getStorageElementTypeFromQuantized(quantType);
 
   return success();
 }
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
       llvm::cast<ShapedType>(getInput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
-    inputETy = quantType.getStorageType();
+    inputETy = getStorageElementTypeFromQuantized(quantType);
   }
   mlir::Type outputETy =
       llvm::cast<ShapedType>(getOutput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
-    outputETy = quantType.getStorageType();
+    outputETy = getStorageElementTypeFromQuantized(quantType);
   }
   if (inputETy != outputETy)
     return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a090e6d4..c55b13dc98cc5 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                             maxAttr, quantBits, filterQuantDim,
                                             isSigned, narrowRange));
 }
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+    quant::QuantizedType quantType) {
+  auto quantEty = quantType.getStorageType();
+  // StorageType doesn't capture the sign information
+  // Explicitly create unsigned type if needed
+  if (!quantType.isSigned()) {
+    quantEty = IntegerType::get(quantEty.getContext(),
+                                quantEty.getIntOrFloatBitWidth(),
+                                IntegerType::Unsigned);
+  }
+  return quantEty;
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fc5ea7710e2c4..84776c47b628d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
   return %1 : tensor<4xi8>
 }
 
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
+func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
+func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
+// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
+// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
+func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+
 // -----
 
 // CHECK-LABEL: @concat_fold
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..652447bd6056e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
 }
 
+// -----
+// CHECK-LABEL: clamp_quantized_unsigned
+func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index f0ad4eb4fdb0b..88dffe7fdd2e8 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -1,13 +1,21 @@
 // RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
 
 // -----
-// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+// CHECK-LABEL: test_build_qtype_unsigned
+func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
   //  CHECK: tosa.negate
-  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
   return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
 }
 
+// -----
+// CHECK-LABEL: test_build_qtype_signed
+func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
+  //  CHECK: tosa.negate
+  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+  return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+}
+
 // -----
 // CHECK-LABEL: test_build_mult_and_shift
 func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..ea64d468f151e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
+    %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
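
The canonicalization tests in the diff above expect two consecutive clamps to fold into one when their ranges overlap. As a rough model of that expectation (not the actual `ClampClampOptimization` code), the sketch below intersects two closed ranges. `Range` and `mergeClamps` are hypothetical names chosen for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

using Range = std::pair<std::int64_t, std::int64_t>; // {min_val, max_val}

// Returns the single clamp range equivalent to applying `first` and then
// `second`, or std::nullopt when the two ranges do not overlap.
std::optional<Range> mergeClamps(Range first, Range second) {
  if (first.first > second.second || first.second < second.first)
    return std::nullopt;
  return Range{std::max(first.first, second.first),
               std::min(first.second, second.second)};
}

// Usage example with the bounds used in the canonicalize.mlir tests.
void checkExamples() {
  // Unsigned case: [10, 240] then [5, 230] folds to [10, 230].
  const auto u = mergeClamps(Range{10, 240}, Range{5, 230});
  assert(u.has_value() && u->first == 10 && u->second == 230);

  // Signed case: [-10, 110] then [-5, 120] folds to [-5, 110].
  const auto s = mergeClamps(Range{-10, 110}, Range{-5, 120});
  assert(s.has_value() && s->first == -5 && s->second == 110);

  // Non-overlapping signed case: [-10, 50] then [60, 120] does not fold.
  assert(!mergeClamps(Range{-10, 50}, Range{60, 120}).has_value());
}
```

In this model the interpretation of the bounds matters: if the unsigned bounds 240 and 230 were read as signed 8-bit values (-16 and -26), the overlap check would treat the two ranges as disjoint and nothing would fold.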

github-actions bot commented Nov 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

sahas3 (Member, Author) commented Nov 24, 2025

Found this issue when running tf-tosa-opt --tosa-legalize-tfl on the following IR:

func.func @main(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
    %0 = "tfl.relu"(%arg0) : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}

as it produces an invalid tosa.clamp due to the same root cause of not considering the sign information:

 "tosa.clamp"(%3) <{max_val = -1 : i8, min_val = 0 : i8, nan_mode = "PROPAGATE"}> : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>

I plan to fix that issue in the TensorFlow repo once this PR lands and is available there.
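
The `max_val = -1 : i8` in the invalid op above is what one would get if the unsigned 8-bit maximum 255 were read back as a signed 8-bit value. A small stand-alone sketch of that arithmetic follows; the helper names `asSigned8` and `asUnsigned8` are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <cstdint>

// Read an 8-bit pattern as a two's-complement signed value.
constexpr int asSigned8(std::uint8_t bits) {
  return bits >= 128 ? static_cast<int>(bits) - 256 : static_cast<int>(bits);
}

// Map a signed 8-bit value (-128..127) back to its unsigned reading (0..255).
constexpr int asUnsigned8(int value) {
  return value < 0 ? value + 256 : value;
}

// Compile-time checks of the two readings of the same bit pattern.
static_assert(asSigned8(255) == -1, "255 read as signed 8-bit is -1");
static_assert(asUnsigned8(-1) == 255, "-1 read as unsigned 8-bit is 255");
static_assert(asSigned8(0) == 0, "0 is the same in both readings");
```

Under this reading, the intended unsigned range [0, 255] turns into `min_val = 0`, `max_val = -1` once the sign information is dropped.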

lhutton1 (Contributor) left a comment

LGTM, thanks @sahas3!

lhutton1 changed the title from "[tosa] : Get quantized element type with sign info." to "[mlir][tosa] Get quantized element type with sign info." on Nov 25, 2025
sahas3 merged commit 177e382 into llvm:main on Nov 25, 2025
10 checks passed