-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Add support for EXT-DOUBLEROUND and EXT-INEXACTROUND #130337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
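@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: TatWai Chong (tatwaichong)
Changes: Adds a concept of EXT-DOUBLEROUND and EXT-INEXACTROUND to the dialect. It also converts the "double_round" attribute on rescale to a string type "rounding_mode" attribute with the following options: "DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".
The validation pass has been updated to ensure "DOUBLE_ROUND" and "INEXACT_ROUND" are only valid when their extensions are available.
Finally, lowerings to arith and linalg have been updated such that a lowering for "INEXACT_ROUND" is not currently supported.
Patch is 63.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130337.diff
25 Files Affected: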
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..db725dbd5e1bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// FFT : Fast Fourier Transform operations.
// VARIABLE : Stateful variable operations.
// CONTROLFLOW : Control Flow operations.
+// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
+// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
+def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
- Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
+ Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
+ Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
]>;
def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..3f87e299cbbdd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2310,7 +2310,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
I32Attr:$input_zp,
I32Attr:$output_zp,
BoolAttr:$scale32,
- BoolAttr:$double_round,
+ Tosa_RoundingTypeAttr:$rounding_mode,
BoolAttr:$per_channel,
BoolAttr: $input_unsigned,
BoolAttr: $output_unsigned
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 969d06afc70d6..88f454f63e6f9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
switch (ext) {
case Extension::int16:
case Extension::int4:
+ case Extension::doubleround:
+ case Extension::inexactround:
return {Profile::pro_int};
case Extension::bf16:
case Extension::fp8e4m3:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 08c0e02139b0c..0038d8c386ca7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
"Supported NaN propagation strategies">;
+// Rounding mode for tosa.rescale
+def Tosa_RoundingTypeAttr : StringBasedAttr<
+ CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
+ "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
+ "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
+ "Supported rounding modes">;
+
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Tensor to buffer types.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 8756cb9e5de3a..8a27e5ba39331 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
Tosa_IntLike:$value,
Tosa_IntLike:$multiplier,
Tosa_Int8Like:$shift,
- BoolAttr:$double_round
+ Tosa_RoundingTypeAttr:$rounding_mode
);
let results = (outs
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 5c84b4063da2e..9dea12355a519 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
+ StringRef roundingMode = op.getRoundingMode();
+ if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ return failure();
+ }
+
Location loc = op.getLoc();
Value value = op.getValue();
Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getDoubleRound()) {
+ if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
+ StringRef roundingMode = op.getRoundingMode();
+ if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ return failure();
+ }
+
Location loc = op.getLoc();
Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getDoubleRound()) {
+ if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..b59e55302a60c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getBoolAttr(false));
+ rewriter.getStringAttr("SINGLE_ROUND"));
if (elementTy.isInteger(32))
return result;
@@ -1374,7 +1374,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getDoubleRound() && !op.getScale32())
+ if (op.getRoundingMode() == "INEXACT_ROUND")
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not currently supported");
+ if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1418,9 +1421,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
+
bool doubleRound =
- op.getDoubleRound() &&
+ op.getRoundingMode() == "DOUBLE_ROUND" &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ StringAttr roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND") :
+ rewriter.getStringAttr("SINGLE_ROUND");
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1515,8 +1521,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
value = nestedBuilder.create<tosa::ApplyScaleOp>(
- loc, nestedBuilder.getI32Type(), value, multiplier, shift,
- nestedBuilder.getBoolAttr(doubleRound));
+ loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode);
// Move to the new zero-point.
value =
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..2dd3d2fb3325d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1005,7 +1005,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter
.create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
poolVal, multiplier, shift,
- rewriter.getBoolAttr(false))
+ rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b8604ef40cc93..70c4cd0a526cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
LogicalResult applyLevelCheck(Operation *op);
+ LogicalResult applyAttributeCheck(Operation *op);
// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,23 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
+ bool attributeCheckRescale(Operation *op) {
+ if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
+ if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+ !targetEnv.allows(Extension::doubleround)) {
+ op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND "
+ << "requires extension [doubleround]";
+ return false;
+ } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
+ !targetEnv.allows(Extension::inexactround)) {
+ op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND "
+ << "requires extension [inexactround]";
+ return false;
+ }
+ }
+ return true;
+ }
+
// configure profile and level values from pass options profileName and
// levelName
void configLevelAndProfile() {
@@ -415,7 +433,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
+ << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
+ << "doubleround and inexactround\n";
return signalPassFailure();
}
}
@@ -642,6 +661,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}
+LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
+ if (!attributeCheckRescale(op))
+ return failure();
+ return success();
+}
+
inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declaredType) {
// for now, simply use type equality comparison
@@ -936,6 +961,10 @@ void TosaValidation::runOnOperation() {
if (failed(applyLevelCheck(op)))
signalPassFailure();
+ // check additional attribute restrictions
+ if (failed(applyAttributeCheck(op)))
+ signalPassFailure();
+
// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
new file mode 100644
index 0000000000000..4b324955439aa
--- /dev/null
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
+
+// CHECK-LABEL: @apply_scale_unsupported_inexact_round
+func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+ // expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+ return %res : i32
+}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 14f811727c456..db68ca40879f4 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
// CHECK: return %[[RESULT]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
return %res : i32
}
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
// CHECK-NOT: "tosa.apply_scale"
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
return %res : vector<4xi32>
}
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
return %res : i32
}
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
return %res : i32
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..54c6ed994e947 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..a89da9a2b9fed 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -411,7 +411,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
// CHECK: %[[C30:.+]] = arith.constant 30
// CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
- // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
+ // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
// Perform the normalization.
// CHECK: %[[CMIN:.+]] = arith.constant -128
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..596069ad7f53d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1151,7 +1151,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1172,7 +1172,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1182,7 +1182,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : t...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but I won't explicitly approve since I authored some of this patch
GeorgeARM
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit adds a concept of EXT-DOUBLEROUND and EXT-INEXACTROUND to the dialect. It also converts the "double_round" attribute on rescale to a string type "rounding_mode" attribute with the following options: "DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND". The validation pass has been updated to ensure "DOUBLE_ROUND" and "INEXACT_ROUND" are only valid when their extensions are available. Finally, lowerings to arith and linalg have been updated such that a lowering for "INEXACT_ROUND" is not currently supported. Co-authored-by: TatWai Chong <[email protected]>
|
Rebase. |
Changes needed for llvm/llvm-project#130337 and llvm/llvm-project#129720 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337 and llvm/llvm-project#129720 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23793 from Tixxx:tixxx/cp_sync 3a8a643 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 738783395
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Changes needed for llvm/llvm-project#130337, llvm/llvm-project#129720, and llvm/llvm-project#130340 PiperOrigin-RevId: 739122578
Adds the concepts of EXT-DOUBLEROUND and EXT-INEXACTROUND
to the dialect. It also replaces the boolean "double_round" attribute on
rescale with a string-typed "rounding_mode" attribute that accepts the
following options: "DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".
The validation pass has been updated to ensure that "DOUBLE_ROUND"
and "INEXACT_ROUND" are only accepted when their respective extensions
are available.
Finally, the lowerings to arith and linalg have been updated to use the new
attribute; they reject "INEXACT_ROUND", for which no lowering is currently
implemented.