Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// FFT : Fast Fourier Transform operations.
// VARIABLE : Stateful variable operations.
// CONTROLFLOW : Control Flow operations.
// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
//===----------------------------------------------------------------------===//

def Tosa_NONE : I32EnumAttrCase<"none", 0>;
Expand All @@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;

def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
]>;

def Tosa_ExtensionArrayAttr
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2347,7 +2347,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
I32Attr:$input_zp,
I32Attr:$output_zp,
BoolAttr:$scale32,
BoolAttr:$double_round,
Tosa_RoundingTypeAttr:$rounding_mode,
BoolAttr:$per_channel,
BoolAttr: $input_unsigned,
BoolAttr: $output_unsigned
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class TosaProfileCompliance {
switch (ext) {
case Extension::int16:
case Extension::int4:
case Extension::doubleround:
case Extension::inexactround:
return {Profile::pro_int};
case Extension::bf16:
case Extension::fp8e4m3:
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
"Supported NaN propagation strategies">;

// Rounding mode attribute: a string attribute whose value must be one of
// "SINGLE_ROUND", "INEXACT_ROUND" or "DOUBLE_ROUND". It is used as the
// `rounding_mode` argument of tosa.rescale and tosa.apply_scale.
def Tosa_RoundingTypeAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
"Supported rounding modes">;

def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;

// Tensor to buffer types.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
Tosa_IntLike:$value,
Tosa_IntLike:$multiplier,
Tosa_Int8Like:$shift,
BoolAttr:$double_round
Tosa_RoundingTypeAttr:$rounding_mode
);

let results = (outs
Expand Down
14 changes: 12 additions & 2 deletions mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
StringRef roundingMode = op.getRoundingMode();
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
return failure();
}

Location loc = op.getLoc();
Value value = op.getValue();
Value multiplier32 = op.getMultiplier();
Expand Down Expand Up @@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);

// Apply double rounding if necessary.
if (op.getDoubleRound()) {
if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
Expand Down Expand Up @@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
StringRef roundingMode = op.getRoundingMode();
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
return failure();
}

Location loc = op.getLoc();

Type resultTy = op.getType();
Expand Down Expand Up @@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);

// Conditionally perform our double round.
if (op.getDoubleRound()) {
if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);
Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(

auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getBoolAttr(false));
rewriter.getStringAttr("SINGLE_ROUND"));

if (elementTy.isInteger(32))
return result;
Expand Down Expand Up @@ -1374,7 +1374,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();

// This is an illegal configuration: terminate and log an error.
if (op.getDoubleRound() && !op.getScale32())
if (op.getRoundingMode() == "INEXACT_ROUND")
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");

Expand Down Expand Up @@ -1418,9 +1422,13 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {

// Double rounding only takes effect when a shift is greater than 31, so
// check whether any shift value actually exceeds that.

bool doubleRound =
op.getDoubleRound() &&
op.getRoundingMode() == "DOUBLE_ROUND" &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
StringAttr roundingMode = doubleRound
? rewriter.getStringAttr("DOUBLE_ROUND")
: rewriter.getStringAttr("SINGLE_ROUND");

SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
Expand Down Expand Up @@ -1516,7 +1524,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {

value = nestedBuilder.create<tosa::ApplyScaleOp>(
loc, nestedBuilder.getI32Type(), value, multiplier, shift,
nestedBuilder.getBoolAttr(doubleRound));
roundingMode);

// Move to the new zero-point.
value =
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,9 +1031,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

auto scaled =
rewriter
.create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
poolVal, multiplier, shift,
rewriter.getBoolAttr(false))
.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), poolVal, multiplier, shift,
rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();

// If we have quantization information we need to apply output
Expand Down
33 changes: 32 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}

LogicalResult applyLevelCheck(Operation *op);
LogicalResult applyAttributeCheck(Operation *op);

// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);
Expand Down Expand Up @@ -386,6 +387,25 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}

// Checks that the rounding mode of a tosa.rescale op is allowed by
// `targetEnv`: DOUBLE_ROUND needs the `doubleround` extension and
// INEXACT_ROUND needs the `inexactround` extension. Emits an op error and
// returns false when the needed extension is not allowed; returns true for
// any other operation or rounding mode.
bool attributeCheckRescale(Operation *op) {
  auto rescaleOp = dyn_cast<tosa::RescaleOp>(op);
  if (!rescaleOp)
    return true;

  const auto mode = rescaleOp.getRoundingMode();

  if (mode == "DOUBLE_ROUND" && !targetEnv.allows(Extension::doubleround)) {
    op->emitOpError()
        << "failed attribute check: rounding_mode = DOUBLE_ROUND "
        << "requires extension [doubleround]";
    return false;
  }

  if (mode == "INEXACT_ROUND" && !targetEnv.allows(Extension::inexactround)) {
    op->emitOpError()
        << "failed attribute check: rounding_mode = INEXACT_ROUND "
        << "requires extension [inexactround]";
    return false;
  }

  return true;
}

// configure profile and level values from pass options profileName and
// levelName
void configLevelAndProfile() {
Expand Down Expand Up @@ -415,7 +435,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
<< "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
<< "doubleround and inexactround\n";
return signalPassFailure();
}
}
Expand Down Expand Up @@ -642,6 +663,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}

// Runs the attribute-level checks on `op` and reports failure if a check
// rejects it. The only check invoked here is attributeCheckRescale.
LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
  return success(attributeCheckRescale(op));
}

inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declaredType) {
// for now, simply use type equality comparison
Expand Down Expand Up @@ -936,6 +963,10 @@ void TosaValidation::runOnOperation() {
if (failed(applyLevelCheck(op)))
signalPassFailure();

// check additional attribute restrictions
if (failed(applyAttributeCheck(op)))
signalPassFailure();

// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics

// The tosa-to-arith apply_scale patterns only accept the SINGLE_ROUND and
// DOUBLE_ROUND rounding modes, so an apply_scale using INEXACT_ROUND is
// expected to fail legalization.
// CHECK-LABEL: @apply_scale_unsupported_inexact_round
func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
return %res : i32
}
8 changes: 4 additions & 4 deletions mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
// CHECK: return %[[RESULT]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
return %res : i32
}

Expand All @@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
// CHECK-NOT: "tosa.apply_scale"
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
return %res : vector<4xi32>
}

Expand Down Expand Up @@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
return %res : i32
}

Expand Down Expand Up @@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
%res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
return %res : i32
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
%0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
%0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
// CHECK: %[[C30:.+]] = arith.constant 30
// CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
// CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
// CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}

// Perform the normalization.
// CHECK: %[[CMIN:.+]] = arith.constant -128
Expand Down
Loading