Conversation

@alaa-ali
Contributor

@alaa-ali alaa-ali commented Mar 6, 2025

This PR fixes an integer overflow issue when casting a value greater than or equal to
float MAX from float32 or float64 to int32 or int64 during the tosa-to-linalg pass for tosa.cast.

This issue was found while debugging a numerical mismatch between tf.cast and tfl.cast.
tfl.cast is lowered to tosa.cast, which casts between these types. The expected values were also confirmed in PyTorch using torch.Tensor.to to cast between similar dtypes, and we chose to fix this overflow issue so that the results match TensorFlow and PyTorch casting.

Without this PR, TensorFlow and PyTorch have similar behavior when casting float max to integer, but the TOSA implementation follows a different spec, which requires overflow results to be saturated.

Example of casting F64 min / max value to I64:
EXPECTED (tf.cast results):
[-9223372036854775808, -9223372036854775808]
FOUND (tosa.cast results):
[-9223372036854775808, 9223372036854775807]

Example of casting F32 min / max value to I32:
EXPECTED (tf.cast results):
[-2147483648, -2147483648 ]
FOUND (tosa.cast results):
[-2147483648, 2147483647 ]
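A minimal Python sketch of that comparison (assuming TensorFlow and PyTorch are available; the exact printed values can vary by platform, since this conversion overflows):

import numpy as np
import tensorflow as tf
import torch

f32_min = np.finfo(np.float32).min   # -3.4028235e+38
f32_max = np.finfo(np.float32).max   #  3.4028235e+38

# Framework casts used while debugging the mismatch:
print(tf.cast(tf.constant([f32_min, f32_max]), tf.int32).numpy())
print(torch.tensor([f32_min, f32_max]).to(torch.int32))

# Values reported above for the f32 -> i32 case:
#   tf.cast / torch:                      [-2147483648, -2147483648]
#   tosa.cast lowering (before this PR):  [-2147483648,  2147483647]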

@github-actions

github-actions bot commented Mar 6, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Alaa Ali (alaa-ali)

Changes

This PR fixes an integer overflow issue when casting a value greater than or equal to
float MAX from float32 or float64 to int32 or int64 during the tosa-to-linalg pass for tosa.cast.

This issue was found while debugging a numerical mismatch between tf.cast and tfl.cast.
tfl.cast is lowered to tosa.cast, which casts between these types. The expected values were also confirmed in PyTorch using torch.Tensor.to to cast between similar dtypes, and we chose to fix this overflow issue so that the results match TensorFlow and PyTorch casting.

Example of casting F64 min / max value to I64:
EXPECTED (tf.cast results):
[-9223372036854775808, -9223372036854775808]
FOUND (tosa.cast results):
[-9223372036854775808, 9223372036854775807]

Example of casting F32 min / max value to I32:
EXPECTED (tf.cast results):
[-2147483648, -2147483648 ]
FOUND (tosa.cast results):
[-2147483648, 2147483647 ]


Full diff: https://github.com/llvm/llvm-project/pull/130116.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-9)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-9)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..8085f1104a4cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -618,12 +618,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
             loc, rewriter.getIntegerAttr(
                      getElementTypeOrSelf(dstTy),
                      APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
-        auto intMax = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(
-                     getElementTypeOrSelf(dstTy),
-                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
         auto maxClamped =
-            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
+            rewriter.create<arith::SelectOp>(loc, overflow, intMin, conv);
         return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                 maxClamped);
       }
@@ -647,8 +643,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                          .getSExtValue()));
 
+        auto overflow = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, rounded, intMaxFP);
+        Value maxClampedFP = rewriter.create<arith::SelectOp>(loc, overflow, intMinFP, rounded);
+
         Value clamped =
-            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
+            clampFloatHelper(loc, maxClampedFP, intMinFP, intMaxFP, rewriter);
         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
       }
 
@@ -664,17 +663,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                            .getSExtValue()) +
                        1.0f));
 
-      auto intMax = rewriter.create<arith::ConstantOp>(
+      auto intMin = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(
                    getElementTypeOrSelf(dstTy),
-                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
+                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
       auto minClampedFP =
           rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
       auto minClamped =
           rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
       auto overflow = rewriter.create<arith::CmpFOp>(
           loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
-      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
+      return rewriter.create<arith::SelectOp>(loc, overflow, intMin,
                                               minClamped);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..a10053c31a8e6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -541,13 +541,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
-  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
+  // CHECK: [[CSTMINF:%.+]] = arith.constant -2.14748365E+9 : f32
   // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
-  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
+  // CHECK: [[CSTMIN:%.+]] = arith.constant -2147483648 : i32
+  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMINF]] : f32
   // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
   // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
-  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
+  // CHECK: arith.select [[CMP]], [[CSTMIN]], [[CONV]] : i32
   %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
@@ -591,7 +591,9 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
   // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
   // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
-  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ugt, [[ROUND]], [[CSTMAX]] : f16
+  // CHECK: [[CLAMPMAX:%.+]] = arith.select [[OVERFLOW]], [[CSTMIN]], [[ROUND]] : f16
+  // CHECK: [[MIN:%.+]] = arith.minimumf [[CLAMPMAX]], [[CSTMAX]] : f16
   // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
   // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
   %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>
@@ -604,8 +606,7 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
   // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
   // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
-  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
-  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
+  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MININT]], [[CONV]] : i32
   // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
   %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
   return
@@ -1980,11 +1981,11 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
 // CHECK:             %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
 // CHECK:             %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
 // CHECK:             %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
-// CHECK:             %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:             %[[INT_MIN:.*]] = arith.constant -9223372036854775808 : i64
 // CHECK:             %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
 // CHECK:             %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
 // CHECK:             %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
-// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
+// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MIN]], %[[FPTOSI]] : i64
 // CHECK:             linalg.yield %[[SELECT]] : i64
 // CHECK:           } -> tensor<1xi64>
 // CHECK:           return %[[RESULT]] : tensor<1xi64>

@llvmbot
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir

@lhutton1 lhutton1 changed the title tosa.cast: fix answer mismatch to cast f64/f32 max value to i64/i32 [mlir][tosa][tosa-to-linalg] tosa.cast: fix answer mismatch to cast f64/f32 max value to i64/i32 Mar 6, 2025
@Jerry-Ge
Member

Jerry-Ge commented Mar 6, 2025

Hi @alaa-ali, thanks for the PR! Could you first clear the code formatting errors?

@Jerry-Ge Jerry-Ge requested a review from FranklandJack March 6, 2025 22:14
@alaa-ali
Contributor Author

alaa-ali commented Mar 7, 2025

Hi @alaa-ali, thanks for the PR! Could you first clear the code formatting errors?

Done. Thank you

@mgehre-amd
Contributor

A difference with TensorFlow isn't enough reason to change the TOSA implementation. Are the new semantics mandated by the TOSA spec?
It says

Casting from floating-point to integer:
...
- Result overflows must be saturated.

and the pseudo implementation is

out = truncate<out_t>(apply_clip_s<i32_t>(round_to_nearest_int(in), minimum_s<out_t>(), maximum_s<out_t>()));

which tells me that the current behavior of casting F64 max to i64, yielding 9223372036854775807, is correct.
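
For illustration, here is a rough scalar model of that pseudocode (my own reading of how round_to_nearest_int and apply_clip_s compose, not the reference model):

import math

def tosa_cast_float_to_int(x: float, bits: int) -> int:
    """Scalar sketch of the spec pseudocode: round to nearest (ties to even),
    then clamp (saturate) to the signed output range."""
    int_min = -(1 << (bits - 1))
    int_max = (1 << (bits - 1)) - 1
    if math.isinf(x):                 # +/-inf saturate to the range bounds
        return int_max if x > 0 else int_min
    rounded = round(x)                # Python round() is round-half-to-even
    return max(int_min, min(int_max, rounded))

# F64 max saturates to the i64 maximum, i.e. the result the current lowering produces:
print(tosa_cast_float_to_int(1.7976931348623157e+308, 64))  # 9223372036854775807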

@sahas3 sahas3 requested review from GeorgeARM and sjarus March 7, 2025 14:04
@sahas3
Member

sahas3 commented Mar 7, 2025

A difference with TensorFlow isn't enough reason to change the TOSA implementation. Are the new semantics mandated by the TOSA spec? It says

Casting from floating-point to integer:
...
- Result overflows must be saturated.

and the pseudo implementation is

out = truncate<out_t>(apply_clip_s<i32_t>(round_to_nearest_int(in), minimum_s<out_t>(), maximum_s<out_t>()));

which tells me that the current behavior of casting F64 max to i64, yielding 9223372036854775807, is correct.

That's a fair point. I suppose that to match numerics, the fix should be in the tfl-to-tosa and torch-to-tosa passes instead, by adding logic similar to this PR's change after emitting tosa.cast in those passes? Added @sjarus and @GeorgeARM to get their thoughts on this as well.

@ivangarcia44
Contributor

ivangarcia44 commented Mar 7, 2025

Result overflows must be saturated.

It would be nice if the TOSA cast operation matched the behavior of PyTorch and TensorFlow, since they are the most popular deep learning frameworks. It would make the lowering between them smoother and less complex, and reduce the risk of bugs.

Maybe a solution to make all front ends happy could be to add an attribute to the TOSA cast operation that determines whether saturation is done on overflow or not. This would facilitate the integration with any front end.

@Hanumanth04

Hanumanth04 commented Mar 7, 2025

A difference with TensorFlow isn't enough reason to change the TOSA implementation. Are the new semantics mandated by the TOSA spec? It says

Casting from floating-point to integer:
...
- Result overflows must be saturated.

and the pseudo implementation is

out = truncate<out_t>(apply_clip_s<i32_t>(round_to_nearest_int(in), minimum_s<out_t>(), maximum_s<out_t>()));

which tells me that the current behavior of casting F64 max to i64, yielding 9223372036854775807, is correct.

My two cents:
According to the C++ specification (please see the Floating-integral conversion section in the link below), the behavior is undefined for the cases mentioned in the solution description. I understand that the TOSA specification went with saturation behavior in this case. However, when comparing translated output numerics, we generally compare against PyTorch and TensorFlow output, so making the TOSA cast specification comply with PyTorch and TensorFlow would help here. At least, I can see how this would be beneficial when comparing single-precision numerics.

https://en.cppreference.com/w/cpp/language/implicit_conversion
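
A quick way to observe this with NumPy (the observed value is platform-dependent precisely because of that undefined behavior, and newer NumPy versions may also emit a RuntimeWarning):

import numpy as np

# Out-of-range float -> signed int conversion; the result here can vary by
# platform/compiler and should not be relied upon.
f32_max = np.finfo(np.float32).max
print(f32_max.astype(np.int32))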

@alaa-ali
Contributor Author

alaa-ali commented Mar 7, 2025

Result overflows must be saturated.

It would be nice if the TOSA cast operation matched the behavior of PyTorch and TensorFlow, since they are the most popular deep learning frameworks. It would make the lowering between them smoother and less complex, and reduce the risk of bugs.

Maybe a solution to make all front ends happy could be to add an attribute to the TOSA cast operation that determines whether saturation is done on overflow or not. This would facilitate the integration with any front end.


Adding a new attribute to select the TOSA cast behavior (defaulting to the saturation behavior from the TOSA spec) seems like a reasonable way to resolve this and provide a solution for multiple use cases.

@eric-k256
Contributor

You're comparing against a single backend of TensorFlow and PyTorch in these cases; I don't see a guarantee that other backends would have the same behavior, which would make the change suggested here mismatch those implementations.

Relying on undefined behavior in your testing seems dangerous, and that is what appears to be happening here. Is there a practical use for changing the TOSA behavior that would be helpful for real-world use cases?

@ivangarcia44
Contributor

You're comparing against a single backend of TensorFlow and PyTorch in these cases; I don't see a guarantee that other backends would have the same behavior, which would make the change suggested here mismatch those implementations.

Relying on undefined behavior in your testing seems dangerous, and that is what appears to be happening here. Is there a practical use for changing the TOSA behavior that would be helpful for real-world use cases?

Since PyTorch/TensorFlow don't saturate on overflow, all hardware backends should do the same for numerical equivalence.

Even if this is not the case, having an attribute on the TOSA cast operation to control the saturation behavior should not hurt. If the default value of this attribute is to saturate, then nothing should be broken. It is up to the user to decide what works best for their infrastructure.

@Endilll Endilll removed their request for review March 8, 2025 15:12
@sahas3
Member

sahas3 commented Mar 8, 2025

Re-reading the TOSA spec I see:

The intent is to enable a variety of implementations running on a diverse range of processors, with the results at the TOSA level consistent across those implementations. .... Most operators from the common ML frameworks (TensorFlow, PyTorch, etc.) should be expressible in TOSA. It is expected that there will be tools to lower from ML frameworks into TOSA."

So it seems that matching the behavior of ML frameworks like TF/PyTorch exactly was never the goal of the TOSA spec. It only serves as an intermediary IR to be able to target various HW backends from the ML frameworks. Interestingly, the behavior of TOSA is in line with what happens in JAX:

>>> import jax.numpy as jnp
>>> max_float32 = jnp.finfo(jnp.float32).max
>>> print(max_float32)
3.4028235e+38
>>> jnp.astype(max_float32, jnp.int32)
Array(2147483647, dtype=int32)

While I understand that having an option to match the TOSA spec to the different ML frameworks would be nice, so that there is a single implementation that does the right thing when lowering from TF/PyTorch to TOSA and subsequently to Linalg, it's hard to argue why that onus should be on the TOSA spec instead of the individual xxx-to-tosa passes.

@GeorgeARM
Contributor

It is rather strange to reason about correctness on the boundaries of undefined behavior.
Truth be told, I think the explicit behavior that TOSA has is the "safest" one and the easiest to deploy against validation flows. Having different behaviours would mean that the legalizations, in order to produce "correct" results, would need to be tuned (unless I am missing something) for the specific backend, which really defeats the purpose of TOSA.
Echoing @eric-k256's question: are there real use-cases that require this wrap-around behaviour, or is it that synthetic validation tests fail?

@ivangarcia44
Contributor

It is rather strange to reason about correctness on the boundaries of undefined behavior. Truth be told, I think the explicit behavior that TOSA has is the "safest" one and the easiest to deploy against validation flows. Having different behaviours would mean that the legalizations, in order to produce "correct" results, would need to be tuned (unless I am missing something) for the specific backend, which really defeats the purpose of TOSA. Echoing @eric-k256's question: are there real use-cases that require this wrap-around behaviour, or is it that synthetic validation tests fail?

I don't know of any backend use case. It seemed to me that it could be low-hanging fruit: guarding the overflow saturation logic in the TOSA cast operation behind a compile-time attribute would make the PyTorch/TensorFlow lowering to TOSA simpler. I have seen such attributes on cast operations in other frameworks, and they make the infrastructure easier, but it's ok not to do it.

@sjarus sjarus requested a review from eric-k256 March 12, 2025 14:45
Contributor

@eric-k256 eric-k256 left a comment

I don't think this should go in. It is different from the TOSA spec behavior, and there isn't anything in the code explaining why a value greater than the maximum representable should be output as the lowest output value.

You could propose a change to the TOSA spec, but adding an attribute to choose the behavior would mean that both paths need to be tested. I don't expect this to be an issue with real-world uses, as relying on undefined behavior is dangerous.

If it was important to match the framework behavior, you could check for overflow before the CAST and then do a SELECT for the appropriate value after the CAST. That way you could get the framework behavior and TOSA doesn't need to change.
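
As a rough scalar illustration of that suggestion (a hypothetical helper, not the actual legalization code; the overflow threshold is the float value of INT_MAX + 1, as in the existing lowering):

def framework_style_cast(x: float, bits: int) -> int:
    """Hypothetical scalar sketch: detect overflow before the (saturating,
    spec-compliant) cast, then SELECT the desired value afterwards, so
    tosa.cast itself stays unchanged."""
    int_min = -(1 << (bits - 1))
    int_max = (1 << (bits - 1)) - 1
    if x >= float(1 << (bits - 1)):             # overflow check before the cast
        return int_min                          # value the frameworks were observed to produce
    rounded = round(x)                          # finite inputs assumed for brevity
    return max(int_min, min(int_max, rounded))  # result of the saturating cast

print(framework_style_cast(3.4028235e+38, 32))  # -2147483648, matching tf.cast above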

@Jerry-Ge
Member

Should we close this PR now? @alaa-ali

@alaa-ali
Contributor Author

Should we close this PR now? @alaa-ali

@Jerry-Ge Yes, we can close it. We will proceed to change this in the tfl-to-tosa pipeline. Thank you.

@Jerry-Ge Jerry-Ge closed this Mar 25, 2025