@@ -1484,49 +1484,25 @@ static void setupLinalgGenericOpInputAndIndexingMap(
14841484}
14851485
14861486// Return the extended Zp to be used in subsequent arithmetic operations.
1487- static Value getExtendInputZp (OpBuilder &builder, Type valueTy,
1488- FailureOr<int64_t > maybeZp, Location loc,
1489- ValueRange blockArgs, int64_t iZpArg) {
1487+ static Value getExtendZp (OpBuilder &builder, Type valueTy,
1488+ FailureOr<int64_t > maybeZp, Location loc,
1489+ ValueRange blockArgs, int64_t zpArg,
1490+ bool isOutputZp = false ) {
14901491 Value result;
1492+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth ();
1493+ const uint32_t attrBitwidth =
1494+ isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32 );
1495+ auto extendType = builder.getIntegerType (attrBitwidth);
14911496 // The Zp value can be either constant or non-constant, depending on
14921497 // whether dynamic extension is enabled.
14931498 // If 'maybeZp' fails, it indicates that Zp is non-constant and will
14941499 // be passed as an input to linalg::GenericOp.
14951500 if (failed (maybeZp)) {
1496- result = blockArgs[iZpArg ];
1501+ result = blockArgs[zpArg ];
14971502 auto zpTy = result.getType ();
1498- if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1499- if (zpTy.isUnsignedInteger ()) {
1500- return builder.create <arith::ExtUIOp>(loc, builder.getI32Type (),
1501- result);
1502- } else {
1503- return builder.create <arith::ExtSIOp>(loc, builder.getI32Type (),
1504- result);
1505- }
1506- }
1507- } else {
1508- const int32_t bitwidth = valueTy.getIntOrFloatBitWidth ();
1509- // Extend zeropoint for sub-32bits widths.
1510- const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32 ;
1511- return builder.create <arith::ConstantOp>(
1512- loc, IntegerAttr::get (builder.getIntegerType (attrBitwidth), *maybeZp));
1513- }
1514- return result;
1515- }
1516-
1517- // Return the i32 outputZp to be used in subsequent arithmetic operations.
1518- static Value getI32OutputZp (OpBuilder &builder, Type valueTy,
1519- FailureOr<int64_t > maybeZp, Location loc,
1520- ValueRange blockArgs, int64_t oZpArg) {
1521- Value result;
1522- // The Zp value can be either constant or non-constant, depending on
1523- // whether dynamic extension is enabled.
1524- // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1525- // be passed as an input to linalg::GenericOp.
1526- if (failed (maybeZp)) {
1527- result = blockArgs[oZpArg];
1528- auto zpTy = result.getType ();
1529- if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1503+ if (zpTy.getIntOrFloatBitWidth () < attrBitwidth) {
1504+ // For ExtUIOp, the input must be signless.
1505+ // UnrealizedConversionCastOp will cast the input to signless type.
15301506 if (zpTy.isUnsignedInteger ()) {
15311507 result =
15321508 UnrealizedConversionCastOp::create (
@@ -1535,16 +1511,14 @@ static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
15351511 .getResult (0 );
15361512 }
15371513 if (zpTy.isUnsignedInteger ()) {
1538- return builder.create <arith::ExtUIOp>(loc, builder.getI32Type (),
1539- result);
1514+ return builder.create <arith::ExtUIOp>(loc, extendType, result);
15401515 } else {
1541- return builder.create <arith::ExtSIOp>(loc, builder.getI32Type (),
1542- result);
1516+ return builder.create <arith::ExtSIOp>(loc, extendType, result);
15431517 }
15441518 }
15451519 } else {
15461520 return builder.create <arith::ConstantOp>(
1547- loc, IntegerAttr::get (builder. getIntegerType ( 32 ) , *maybeZp));
1521+ loc, IntegerAttr::get (extendType , *maybeZp));
15481522 }
15491523 return result;
15501524}
@@ -1687,12 +1661,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16871661 Type valueTy = value.getType ();
16881662
16891663 FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1690- auto inputZp = getExtendInputZp (nestedBuilder, valueTy, maybeIZp,
1691- nestedLoc, blockArgs, iZpArg);
1664+ auto inputZp = getExtendZp (nestedBuilder, valueTy, maybeIZp,
1665+ nestedLoc, blockArgs, iZpArg);
16921666
16931667 FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1694- auto outputZp = getI32OutputZp (nestedBuilder, valueTy, maybeOZp,
1695- nestedLoc, blockArgs, oZpArg);
1668+ auto outputZp = getExtendZp (nestedBuilder, valueTy, maybeOZp,
1669+ nestedLoc, blockArgs, oZpArg, true );
16961670
16971671 IntegerType outIntType =
16981672 cast<IntegerType>(blockArgs.back ().getType ());
0 commit comments