Skip to content

Commit 8a47f3c

Browse files
committed
Add rescale_no_const_per_channel_output_zp_ui8 test case
1 parent 693bda0 commit 8a47f3c

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,6 +1527,13 @@ static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
15271527
result = blockArgs[oZpArg];
15281528
auto zpTy = result.getType();
15291529
if (zpTy.getIntOrFloatBitWidth() < 32) {
1530+
if (zpTy.isUnsignedInteger()) {
1531+
result =
1532+
UnrealizedConversionCastOp::create(
1533+
builder, loc,
1534+
builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1535+
.getResult(0);
1536+
}
15301537
if (zpTy.isUnsignedInteger()) {
15311538
return builder.create<arith::ExtUIOp>(loc, builder.getI32Type(),
15321539
result);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,35 @@ func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi
15941594

15951595
// -----
15961596

1597+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
1598+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
1599+
// CHECK-LABEL: @rescale_no_const_per_channel_output_zp_ui8
1600+
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
1601+
// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
1602+
// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
1603+
func.func @rescale_no_const_per_channel_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) {
1604+
// CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
1605+
// CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8>
1606+
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8>
1607+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
1608+
// CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
1609+
// CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
1610+
// CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
1611+
// CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
1612+
// CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
1613+
// CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
1614+
// CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
1615+
// CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
1616+
// CHECK: %c0_i32 = arith.constant 0 : i32
1617+
// CHECK: %c255_i32 = arith.constant 255 : i32
1618+
// CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32
1619+
// CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32
1620+
%0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xui8>) -> tensor<2xui8>
1621+
return %0 : tensor<2xui8>
1622+
}
1623+
1624+
// -----
1625+
15971626
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
15981627

15991628
// CHECK-LABEL: @reverse

0 commit comments

Comments
 (0)