Skip to content

Commit d9bcd1e

Browse files
committed
Introduce getExtendZp for both inputZp and outputZp
1 parent 8acc72c commit d9bcd1e

File tree

2 files changed

+29
-54
lines changed

2 files changed

+29
-54
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,49 +1484,25 @@ static void setupLinalgGenericOpInputAndIndexingMap(
14841484
}
14851485

14861486
// Return the extended Zp to be used in subsequent arithmetic operations.
1487-
static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
1488-
FailureOr<int64_t> maybeZp, Location loc,
1489-
ValueRange blockArgs, int64_t iZpArg) {
1487+
static Value getExtendZp(OpBuilder &builder, Type valueTy,
1488+
FailureOr<int64_t> maybeZp, Location loc,
1489+
ValueRange blockArgs, int64_t zpArg,
1490+
bool isOutputZp = false) {
14901491
Value result;
1492+
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1493+
const uint32_t attrBitwidth =
1494+
isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1495+
auto extendType = builder.getIntegerType(attrBitwidth);
14911496
// The Zp value can be either constant or non-constant, depending on
14921497
// whether dynamic extension is enabled.
14931498
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
14941499
// be passed as an input to linalg::GenericOp.
14951500
if (failed(maybeZp)) {
1496-
result = blockArgs[iZpArg];
1501+
result = blockArgs[zpArg];
14971502
auto zpTy = result.getType();
1498-
if (zpTy.getIntOrFloatBitWidth() < 32) {
1499-
if (zpTy.isUnsignedInteger()) {
1500-
return builder.create<arith::ExtUIOp>(loc, builder.getI32Type(),
1501-
result);
1502-
} else {
1503-
return builder.create<arith::ExtSIOp>(loc, builder.getI32Type(),
1504-
result);
1505-
}
1506-
}
1507-
} else {
1508-
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1509-
// Extend zeropoint for sub-32bits widths.
1510-
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
1511-
return builder.create<arith::ConstantOp>(
1512-
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
1513-
}
1514-
return result;
1515-
}
1516-
1517-
// Return the i32 outputZp to be used in subsequent arithmetic operations.
1518-
static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
1519-
FailureOr<int64_t> maybeZp, Location loc,
1520-
ValueRange blockArgs, int64_t oZpArg) {
1521-
Value result;
1522-
// The Zp value can be either constant or non-constant, depending on
1523-
// whether dynamic extension is enabled.
1524-
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1525-
// be passed as an input to linalg::GenericOp.
1526-
if (failed(maybeZp)) {
1527-
result = blockArgs[oZpArg];
1528-
auto zpTy = result.getType();
1529-
if (zpTy.getIntOrFloatBitWidth() < 32) {
1503+
if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1504+
// For ExtUIOp, the input must be signless.
1505+
// UnrealizedConversionCastOp will cast the input to signless type.
15301506
if (zpTy.isUnsignedInteger()) {
15311507
result =
15321508
UnrealizedConversionCastOp::create(
@@ -1535,16 +1511,14 @@ static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
15351511
.getResult(0);
15361512
}
15371513
if (zpTy.isUnsignedInteger()) {
1538-
return builder.create<arith::ExtUIOp>(loc, builder.getI32Type(),
1539-
result);
1514+
return builder.create<arith::ExtUIOp>(loc, extendType, result);
15401515
} else {
1541-
return builder.create<arith::ExtSIOp>(loc, builder.getI32Type(),
1542-
result);
1516+
return builder.create<arith::ExtSIOp>(loc, extendType, result);
15431517
}
15441518
}
15451519
} else {
15461520
return builder.create<arith::ConstantOp>(
1547-
loc, IntegerAttr::get(builder.getIntegerType(32), *maybeZp));
1521+
loc, IntegerAttr::get(extendType, *maybeZp));
15481522
}
15491523
return result;
15501524
}
@@ -1687,12 +1661,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16871661
Type valueTy = value.getType();
16881662

16891663
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1690-
auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
1691-
nestedLoc, blockArgs, iZpArg);
1664+
auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1665+
nestedLoc, blockArgs, iZpArg);
16921666

16931667
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1694-
auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
1695-
nestedLoc, blockArgs, oZpArg);
1668+
auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1669+
nestedLoc, blockArgs, oZpArg, true);
16961670

16971671
IntegerType outIntType =
16981672
cast<IntegerType>(blockArgs.back().getType());

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,19 +1596,20 @@ func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi
15961596

15971597
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
15981598
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
1599-
// CHECK-LABEL: @rescale_no_const_per_channel_output_zp_ui8
1599+
// CHECK-LABEL: @rescale_no_const_per_channel_input_output_zp_ui8
16001600
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
16011601
// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
16021602
// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
1603-
func.func @rescale_no_const_per_channel_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) {
1604-
// CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
1603+
func.func @rescale_no_const_per_channel_input_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xui8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) {
1604+
// CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xui8> into tensor<ui8>
16051605
// CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8>
16061606
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8>
1607-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
1608-
// CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
1609-
// CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
1610-
// CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
1611-
// CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
1607+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<ui8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
1608+
// CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: ui8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
1609+
// CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG3]] : ui8 to i8
1610+
// CHECK: [[INPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
1611+
// CHECK: [[OUTPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
1612+
// CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[OUTPUT_ZP_I8]] : i8 to i32
16121613
// CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
16131614
// CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
16141615
// CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
@@ -1617,7 +1618,7 @@ func.func @rescale_no_const_per_channel_output_zp_ui8(%arg0 : tensor<2xi8>, %arg
16171618
// CHECK: %c255_i32 = arith.constant 255 : i32
16181619
// CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32
16191620
// CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32
1620-
%0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xui8>) -> tensor<2xui8>
1621+
%0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xui8>, tensor<1xui8>) -> tensor<2xui8>
16211622
return %0 : tensor<2xui8>
16221623
}
16231624

0 commit comments

Comments
 (0)