@tatwaichong
Contributor

Adds the EXT-DOUBLEROUND and EXT-INEXACTROUND extensions to the TOSA
dialect. It also replaces the boolean "double_round" attribute on rescale
with a string-typed "rounding_mode" attribute that accepts one of:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".
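
As a rough illustration of the attribute change, the sketch below restates the accepted strings in standalone C++. It does not use MLIR; `isValidRoundingMode` and `roundingModeFromDoubleRound` are hypothetical helper names. The old-to-new mapping is an assumption inferred from the updated tests in this patch, where `double_round = true` becomes "DOUBLE_ROUND" and `double_round = false` becomes "SINGLE_ROUND".

```cpp
#include <algorithm>
#include <array>
#include <string_view>

// The three strings accepted by the new "rounding_mode" attribute.
constexpr std::array<std::string_view, 3> kRoundingModes = {
    "SINGLE_ROUND", "INEXACT_ROUND", "DOUBLE_ROUND"};

// Hypothetical helper: true if `mode` is one of the accepted strings.
// The comparison is case-sensitive, like the string equality in the patch.
inline bool isValidRoundingMode(std::string_view mode) {
  return std::find(kRoundingModes.begin(), kRoundingModes.end(), mode) !=
         kRoundingModes.end();
}

// Hypothetical helper: maps the old boolean attribute onto the new string
// (assumed mapping, inferred from the test updates in this patch).
inline std::string_view roundingModeFromDoubleRound(bool doubleRound) {
  return doubleRound ? "DOUBLE_ROUND" : "SINGLE_ROUND";
}
```

"INEXACT_ROUND" has no counterpart in the old boolean attribute, so a mapping of this kind never produces it.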

The validation pass has been updated to ensure "DOUBLE_ROUND"
and "INEXACT_ROUND" are only valid when their extensions are
available.
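
The extension gating described above could be sketched roughly as follows. This is a standalone approximation, not the code of the pass: `TargetEnv` here is a hypothetical stand-in for the validation pass's target environment, and `checkRescaleRoundingMode` is an illustrative name. The error strings follow the ones added in `TosaValidation.cpp`.

```cpp
#include <optional>
#include <set>
#include <string>
#include <string_view>

// Hypothetical stand-in for the validation pass's target environment.
struct TargetEnv {
  std::set<std::string> enabledExtensions;
  bool allows(std::string_view ext) const {
    return enabledExtensions.count(std::string(ext)) != 0;
  }
};

// Returns an error message when the rounding mode needs an extension that is
// not enabled, and std::nullopt when the mode is allowed.
inline std::optional<std::string>
checkRescaleRoundingMode(std::string_view mode, const TargetEnv &env) {
  if (mode == "DOUBLE_ROUND" && !env.allows("doubleround"))
    return "failed attribute check: rounding_mode = DOUBLE_ROUND "
           "requires extension [doubleround]";
  if (mode == "INEXACT_ROUND" && !env.allows("inexactround"))
    return "failed attribute check: rounding_mode = INEXACT_ROUND "
           "requires extension [inexactround]";
  // SINGLE_ROUND needs no extension.
  return std::nullopt;
}
```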

Finally, the lowerings to arith and linalg have been updated to use the new
attribute. No lowering for "INEXACT_ROUND" is currently supported, so both
lowerings reject it.
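
To make the difference between the supported modes concrete, here is a scalar sketch of the scaling arithmetic. It is an approximation written from the generic arith lowering touched by this patch and from my reading of the TOSA specification's `apply_scale_32` pseudocode, so treat the exact rounding terms as an assumption. `applyScale` is a hypothetical name, and the sketch assumes `shift` is in the range [2, 62] and `multiplier` is non-negative.

```cpp
#include <cstdint>
#include <optional>
#include <string_view>

// Hypothetical scalar model of tosa.apply_scale.
// Returns std::nullopt for modes with no lowering (such as INEXACT_ROUND),
// mirroring the `return failure()` guard added in TosaToArith.cpp.
inline std::optional<int32_t> applyScale(int32_t value, int32_t multiplier,
                                         int8_t shift,
                                         std::string_view roundingMode) {
  if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND")
    return std::nullopt;

  // Round-to-nearest term for the final right shift.
  int64_t round = int64_t{1} << (shift - 1);

  // Double rounding adds an extra +/- 2^30 term, and only when shift > 31.
  if (roundingMode == "DOUBLE_ROUND" && shift > 31)
    round += (value >= 0) ? (int64_t{1} << 30) : -(int64_t{1} << 30);

  int64_t result = int64_t{value} * int64_t{multiplier} + round;
  // Assumes an arithmetic right shift for negative values (guaranteed from
  // C++20, and the behavior of mainstream compilers before that).
  return static_cast<int32_t>(result >> shift);
}
```

With `value = 1`, `multiplier = 1 << 30` and `shift = 32` (a scale of 0.25), this sketch returns 0 under "SINGLE_ROUND" and 1 under "DOUBLE_ROUND". With `shift = 31` the two modes agree, because the extra term only applies when the shift exceeds 31.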

@llvmbot
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

Patch is 63.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130337.diff

25 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+6-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+1-1)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+12-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+10-5)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+30-1)
  • (added) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+8)
  • (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+25-15)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+20-19)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+30)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+12-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..db725dbd5e1bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // FFT          : Fast Fourier Transform operations.
 // VARIABLE     : Stateful variable operations.
 // CONTROLFLOW  : Control Flow operations.
+// DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
+// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
 def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
 def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
+def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..3f87e299cbbdd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2310,7 +2310,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     I32Attr:$input_zp,
     I32Attr:$output_zp,
     BoolAttr:$scale32,
-    BoolAttr:$double_round,
+    Tosa_RoundingTypeAttr:$rounding_mode,
     BoolAttr:$per_channel,
     BoolAttr: $input_unsigned,
     BoolAttr: $output_unsigned
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 969d06afc70d6..88f454f63e6f9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
     switch (ext) {
     case Extension::int16:
     case Extension::int4:
+    case Extension::doubleround:
+    case Extension::inexactround:
       return {Profile::pro_int};
     case Extension::bf16:
     case Extension::fp8e4m3:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 08c0e02139b0c..0038d8c386ca7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
           "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
     "Supported NaN propagation strategies">;
 
+// Rounding mode for tosa.rescale
+def Tosa_RoundingTypeAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
+    "Supported rounding modes">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 8756cb9e5de3a..8a27e5ba39331 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
     Tosa_IntLike:$value,
     Tosa_IntLike:$multiplier,
     Tosa_Int8Like:$shift,
-    BoolAttr:$double_round
+    Tosa_RoundingTypeAttr:$rounding_mode
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 5c84b4063da2e..9dea12355a519 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
     Value value = op.getValue();
     Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
 
     Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..b59e55302a60c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
         auto result = rewriter.create<tosa::ApplyScaleOp>(
             loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getBoolAttr(false));
+            rewriter.getStringAttr("SINGLE_ROUND"));
 
         if (elementTy.isInteger(32))
           return result;
@@ -1374,7 +1374,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getDoubleRound() && !op.getScale32())
+    if (op.getRoundingMode() == "INEXACT_ROUND")
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not currently supported");
+    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1418,9 +1421,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // Double round only occurs if shift is greater than 31, check that this
     // is ever true.
+
     bool doubleRound =
-        op.getDoubleRound() &&
+        op.getRoundingMode() == "DOUBLE_ROUND" &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    StringAttr roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND") :
+        rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1515,8 +1521,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
 
           value = nestedBuilder.create<tosa::ApplyScaleOp>(
-              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
-              nestedBuilder.getBoolAttr(doubleRound));
+              loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode);
 
           // Move to the new zero-point.
           value =
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..2dd3d2fb3325d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1005,7 +1005,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter
                     .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                 poolVal, multiplier, shift,
-                                                rewriter.getBoolAttr(false))
+                                                rewriter.getStringAttr("SINGLE_ROUND"))
                     .getResult();
 
             // If we have quantization information we need to apply output
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b8604ef40cc93..70c4cd0a526cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   }
 
   LogicalResult applyLevelCheck(Operation *op);
+  LogicalResult applyAttributeCheck(Operation *op);
 
   // check variable read/write data types against variable declarations
   LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,23 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  bool attributeCheckRescale(Operation *op) {
+    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
+      if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+          !targetEnv.allows(Extension::doubleround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND "
+                          << "requires extension [doubleround]";
+        return false;
+      } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
+          !targetEnv.allows(Extension::inexactround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND "
+                          << "requires extension [inexactround]";
+        return false;
+      }
+    }
+    return true;
+  }
+
   // configure profile and level values from pass options profileName and
   // levelName
   void configLevelAndProfile() {
@@ -415,7 +433,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
+                       << "doubleround and inexactround\n";
           return signalPassFailure();
         }
       }
@@ -642,6 +661,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   return success();
 }
 
+LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
+  if (!attributeCheckRescale(op))
+    return failure();
+  return success();
+}
+
 inline bool CompatibleTypes(const mlir::Type &type,
                             const mlir::Type &declaredType) {
   // for now, simply use type equality comparison
@@ -936,6 +961,10 @@ void TosaValidation::runOnOperation() {
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
 
+    // check additional attribute restrictions
+    if (failed(applyAttributeCheck(op)))
+      signalPassFailure();
+
     // do variable type checks
     if (failed(applyVariableCheck(op)))
       signalPassFailure();
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
new file mode 100644
index 0000000000000..4b324955439aa
--- /dev/null
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
+
+// CHECK-LABEL: @apply_scale_unsupported_inexact_round
+func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+  return %res : i32
+}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 14f811727c456..db68ca40879f4 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
   // CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
   // CHECK: return %[[RESULT]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
   return %res : i32
 }
 
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
 // SCALE: tosa.apply_scale
 func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
   // CHECK-NOT: "tosa.apply_scale"
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
   return %res : vector<4xi32>
 }
 
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
   return %res : i32
 }
 
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
   return %res : i32
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..54c6ed994e947 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..a89da9a2b9fed 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -411,7 +411,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
   // CHECK: %[[C30:.+]] = arith.constant 30
   // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
-  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
+  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
 
   // Perform the normalization.
   // CHECK: %[[CMIN:.+]] = arith.constant -128
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..596069ad7f53d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1151,7 +1151,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1172,7 +1172,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1182,7 +1182,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : t...
[truncated]
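
For readers skimming the `RescaleConverter` hunks above, the mode-selection logic can be restated as a small standalone function. This is only a sketch of what the diff appears to do: `selectApplyScaleRoundingMode` is a hypothetical name, and `std::nullopt` stands in for `notifyMatchFailure`.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical restatement of how the linalg lowering picks the rounding
// mode it passes on to tosa.apply_scale.
inline std::optional<std::string>
selectApplyScaleRoundingMode(std::string_view rescaleMode, bool scale32,
                             const std::vector<int32_t> &shiftValues) {
  // INEXACT_ROUND has no lowering yet.
  if (rescaleMode == "INEXACT_ROUND")
    return std::nullopt;
  // Double rounding without scale32 is an illegal configuration.
  if (rescaleMode == "DOUBLE_ROUND" && !scale32)
    return std::nullopt;

  // Double rounding only has an effect when some shift exceeds 31;
  // otherwise the lowering falls back to single rounding.
  bool doubleRound =
      rescaleMode == "DOUBLE_ROUND" &&
      std::any_of(shiftValues.begin(), shiftValues.end(),
                  [](int32_t v) { return v > 31; });
  return std::string(doubleRound ? "DOUBLE_ROUND" : "SINGLE_ROUND");
}
```

One consequence of this design is that a rescale marked "DOUBLE_ROUND" whose shifts are all 31 or less produces `tosa.apply_scale` ops tagged "SINGLE_ROUND".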

@github-actions

github-actions bot commented Mar 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@lhutton1 lhutton1 left a comment

LGTM, but I won't explicitly approve since I authored some of this patch

Contributor

@GeorgeARM GeorgeARM left a comment

LGTM

This commit adds a concept of EXT-DOUBLEROUND and EXT-INEXACTROUND
to the dialect. It also converts the "double_round" attribute on rescale
to a string type "rounding_mode" attribute with the following options:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".

The validation pass has been updated to ensure "DOUBLE_ROUND" and
"INEXACT_ROUND" are only valid when their extensions are available.

Finally, the lowerings to arith and linalg have been updated to use the new
attribute. No lowering for "INEXACT_ROUND" is currently supported, so both
lowerings reject it.

Co-authored-by: TatWai Chong <[email protected]>
@tatwaichong
Contributor Author

Rebase.

@GeorgeARM GeorgeARM merged commit 3fb8cb6 into llvm:main Mar 10, 2025
9 of 10 checks passed
@tatwaichong tatwaichong deleted the ext_round branch March 19, 2025 17:59
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 20, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 20, 2025
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 21, 2025
Changes needed for llvm/llvm-project#130337,
llvm/llvm-project#129720,
and llvm/llvm-project#130340

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643
PiperOrigin-RevId: 738783395
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to openxla/shardy that referenced this pull request Mar 21, 2025
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Mar 21, 2025