@@ -1132,11 +1132,21 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11321132 // CHECK-DAG: linalg.yield [[TRUNC]]
11331133 %0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 xi8 >) -> tensor <2 xi8 >
11341134
1135+ // CHECK: return
1136+ return
1137+ }
1138+
1139+ // -----
1140+ // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1141+
1142+ // CHECK-LABEL: @rescale_i8_unsigned_output
1143+ // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1144+ func.func @rescale_i8_unsigned_output (%arg0 : tensor <2 xi8 >) -> () {
11351145 // CHECK: [[C0:%.+]] = arith.constant 19689
11361146 // CHECK: [[C1:%.+]] = arith.constant 15
11371147 // CHECK: [[INIT:%.+]] = tensor.empty()
1138- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8 >)
1139- // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8 ):
1148+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8 >)
1149+ // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8 ):
11401150 // CHECK: [[C17:%.+]] = arith.constant 17
11411151 // CHECK: [[C22:%.+]] = arith.constant 22
11421152 // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
@@ -1148,9 +1158,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11481158 // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
11491159 // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
11501160 // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1151- // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1152- // CHECK: linalg.yield [[CAST]]
1153- %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 xi8 >) -> tensor <2 xui8 >
1161+ // CHECK: linalg.yield [[TRUNC]]
1162+ %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , output_unsigned = true } : (tensor <2 xi8 >) -> tensor <2 xi8 >
11541163
11551164 // CHECK: return
11561165 return
@@ -1171,9 +1180,9 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
11711180
11721181 // CHECK: %[[C0:.+]] = arith.constant 0
11731182 // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1174- // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8 >
1175- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8 >)
1176- %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <?x2 xi8 >) -> tensor <?x 2 x ui8 >
1183+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8 >
1184+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8 >)
1185+ %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , output_unsigned = true } : (tensor <?x2 xi8 >) -> tensor <?x 2 x i8 >
11771186
11781187 return
11791188}
@@ -1199,18 +1208,17 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
11991208
12001209// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12011210
1202- // CHECK-LABEL: @rescale_ui8
1211+ // CHECK-LABEL: @rescale_i8_unsigned_input
12031212// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1204- func.func @rescale_ui8 (%arg0 : tensor <2 x ui8 >) -> () {
1213+ func.func @rescale_i8_unsigned_input (%arg0 : tensor <2 x i8 >) -> () {
12051214 // CHECK: [[C0:%.+]] = arith.constant 19689
12061215 // CHECK: [[C1:%.+]] = arith.constant 15
12071216 // CHECK: [[INIT:%.+]] = tensor.empty()
1208- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8 >) outs([[INIT]] : tensor<2xi8>)
1209- // CHECK: ^bb0([[IN:%.+]]: ui8 , [[UNUSED:%.+]]: i8):
1217+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8 >) outs([[INIT]] : tensor<2xi8>)
1218+ // CHECK: ^bb0([[IN:%.+]]: i8 , [[UNUSED:%.+]]: i8):
12101219 // CHECK: [[C17:%.+]] = arith.constant 17
12111220 // CHECK: [[C22:%.+]] = arith.constant 22
1212- // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1213- // CHECK-DAG: [[IN32:%.+]] = arith.extui [[CAST]]
1221+ // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
12141222 // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
12151223 // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
12161224 // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
@@ -1220,7 +1228,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
12201228 // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
12211229 // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
12221230 // CHECK: linalg.yield [[TRUNC]]
1223- %0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 x ui8 >) -> tensor <2 xi8 >
1231+ %0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , input_unsigned = true } : (tensor <2 x i8 >) -> tensor <2 xi8 >
12241232
12251233 return
12261234}
0 commit comments