Skip to content

Commit 9c357ad

Browse files
authored
fix(mlir,tosa): restore unrealized casts when lowering tosa.rescale to linalg (#15)
Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts between signed and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective, and TF-to-TOSA's `ConvertUint8ToInt8` pass generates just that. The change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal `arith` casts that are constrained to signless types. Whether the `arith` casts themselves are signed or unsigned should still depend on the rescale's `*_unsigned` attribute values.
1 parent 96007f1 commit 9c357ad

File tree

2 files changed

+127
-10
lines changed

2 files changed

+127
-10
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,6 +1505,15 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15051505
: blockArgs[multiplierArg];
15061506
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
15071507

1508+
if (valueTy.isUnsignedInteger()) {
1509+
value = nestedBuilder
1510+
.create<UnrealizedConversionCastOp>(
1511+
nestedLoc,
1512+
nestedBuilder.getIntegerType(
1513+
valueTy.getIntOrFloatBitWidth()),
1514+
value)
1515+
.getResult(0);
1516+
}
15081517
if (valueTy.getIntOrFloatBitWidth() < 32) {
15091518
if (op.getInputUnsigned()) {
15101519
value = nestedBuilder.create<arith::ExtUIOp>(
@@ -1554,6 +1563,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15541563
value);
15551564
}
15561565

1566+
if (outIntType.isUnsignedInteger()) {
1567+
value = nestedBuilder
1568+
.create<UnrealizedConversionCastOp>(nestedLoc,
1569+
outIntType, value)
1570+
.getResult(0);
1571+
}
15571572
nestedBuilder.create<linalg::YieldOp>(loc, value);
15581573
});
15591574

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,25 +1152,60 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11521152
// -----
11531153
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
11541154

1155-
// CHECK-LABEL: @rescale_i8_unsigned_output
1155+
// CHECK-LABEL: @rescale_i8_unsigned_output_explicit
11561156
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1157-
func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
1157+
func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
1158+
// CHECK: [[C0:%.+]] = arith.constant 19689
1159+
// CHECK: [[C1:%.+]] = arith.constant 15
1160+
// CHECK: [[INIT:%.+]] = tensor.empty()
1161+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
1162+
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
1163+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1164+
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
1165+
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
1166+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1167+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1168+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1169+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1170+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1171+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1172+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1173+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1174+
// CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1175+
// CHECK: linalg.yield [[TRUNC_ITOU]]
1176+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1177+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1178+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1179+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1180+
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
1181+
1182+
// CHECK: return
1183+
return
1184+
}
1185+
1186+
// -----
1187+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1188+
1189+
// CHECK-LABEL: @rescale_i8_unsigned_output_implicit
1190+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1191+
func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
11581192
// CHECK: [[C0:%.+]] = arith.constant 19689
11591193
// CHECK: [[C1:%.+]] = arith.constant 15
11601194
// CHECK: [[INIT:%.+]] = tensor.empty()
11611195
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
11621196
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1163-
// CHECK: [[C17:%.+]] = arith.constant 17
1164-
// CHECK: [[C234:%.+]] = arith.constant 234
1197+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1198+
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
11651199
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
11661200
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
11671201
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
11681202
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
11691203
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
11701204
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
11711205
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1172-
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1173-
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1206+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1207+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1208+
// CHECK-NOT: builtin.unrealized_conversion_cast [[TRUNC]] : i8 to i8
11741209
// CHECK: linalg.yield [[TRUNC]]
11751210
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
11761211
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
@@ -1230,19 +1265,52 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
12301265
}
12311266

12321267
// -----
1268+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1269+
1270+
// CHECK-LABEL: @rescale_i8_unsigned_input_explicit
1271+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1272+
func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
1273+
// CHECK: [[C0:%.+]] = arith.constant 19689
1274+
// CHECK: [[C1:%.+]] = arith.constant 15
1275+
// CHECK: [[INIT:%.+]] = tensor.empty()
1276+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
1277+
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
1278+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1279+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
1280+
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1281+
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
1282+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1283+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1284+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
1285+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
1286+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
1287+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1288+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1289+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1290+
// CHECK: linalg.yield [[TRUNC]]
1291+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1292+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1293+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1294+
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
1295+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
1296+
1297+
return
1298+
}
12331299

1300+
// -----
12341301
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12351302

1236-
// CHECK-LABEL: @rescale_i8_unsigned_input
1303+
// CHECK-LABEL: @rescale_i8_unsigned_input_implicit
12371304
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1238-
func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
1305+
func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
12391306
// CHECK: [[C0:%.+]] = arith.constant 19689
12401307
// CHECK: [[C1:%.+]] = arith.constant 15
12411308
// CHECK: [[INIT:%.+]] = tensor.empty()
12421309
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
12431310
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1244-
// CHECK: [[C17:%.+]] = arith.constant 17
1245-
// CHECK: [[C22:%.+]] = arith.constant 22
1311+
// CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] : i8 to i8
1312+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1313+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
12461314
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
12471315
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
12481316
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1262,6 +1330,40 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12621330
return
12631331
}
12641332

1333+
// -----
1334+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1335+
1336+
// CHECK-LABEL: @rescale_i8_unsigned_input_output_explicit
1337+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1338+
func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> () {
1339+
// CHECK: [[C0:%.+]] = arith.constant 19689
1340+
// CHECK: [[C1:%.+]] = arith.constant 15
1341+
// CHECK: [[INIT:%.+]] = tensor.empty()
1342+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xui8>)
1343+
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: ui8):
1344+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1345+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
1346+
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1347+
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
1348+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1349+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1350+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
1351+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
1352+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
1353+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1354+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1355+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1356+
// CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1357+
// CHECK: linalg.yield [[TRUNC_ITOU]]
1358+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1359+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1360+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1361+
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
1362+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
1363+
1364+
return
1365+
}
1366+
12651367
// -----
12661368

12671369
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

0 commit comments

Comments
 (0)