@@ -1594,6 +1594,35 @@ func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi
15941594
15951595// -----
15961596
1597+ // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
1598+ // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
1599+ // CHECK-LABEL: @rescale_no_const_per_channel_output_zp_ui8
1600+ // CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
1601+ // CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
1602+ // CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
1603+ func.func @rescale_no_const_per_channel_output_zp_ui8 (%arg0 : tensor <2 xi8 >, %arg1 : tensor <2 xi32 >, %arg2 : tensor <2 xi8 >, %input_zp : tensor <1 xi8 >, %output_zp : tensor <1 xui8 >) -> (tensor <2 xui8 >) {
1604+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
1605+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8>
1606+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8>
1607+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
1608+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
1609+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
1610+ // CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
1611+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
1612+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
1613+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
1614+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
1615+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
1616+ // CHECK: %c0_i32 = arith.constant 0 : i32
1617+ // CHECK: %c255_i32 = arith.constant 255 : i32
1618+ // CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32
1619+ // CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32
1620+ %0 = tosa.rescale %arg0 , %arg1 , %arg2 , %input_zp , %output_zp {scale32 = true , rounding_mode = DOUBLE_ROUND , per_channel = true , input_unsigned = false , output_unsigned = true } : (tensor <2 xi8 >, tensor <2 xi32 >, tensor <2 xi8 >, tensor <1 xi8 >, tensor <1 xui8 >) -> tensor <2 xui8 >
1621+ return %0 : tensor <2 xui8 >
1622+ }
1623+
1624+ // -----
1625+
15971626// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
15981627
15991628// CHECK-LABEL: @reverse
0 commit comments