@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
8282 rhsOrResult);
8383}
8484
85- template < typename T>
85+ // Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
8686static arith::ConstantOp
87- createConstOpFromZpVal (Operation *op, const int64_t &zp, Type requiredAttrType,
88- OpBuilder &rewriter) {
89- auto castedN = static_cast <T>(zp);
87+ createConstOpFromSExtZp (int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
88+ bool isSigned, Location loc, OpBuilder &rewriter) {
89+
90+ // Zero the signed-extended bits if isSigned is false.
91+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1 );
92+
9093 return rewriter.create <arith::ConstantOp>(
91- op-> getLoc () , IntegerAttr::get (requiredAttrType, castedN ));
94+ loc , IntegerAttr::get (rewriter. getIntegerType (attrBitwidth), zp ));
9295}
9396
9497static Value createLinalgBodyCalculationForElementwiseOp (
@@ -1467,20 +1470,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14671470 Value value = blockArgs[0 ];
14681471 Type valueTy = value.getType ();
14691472
1470- // For now we do all of our math in 64-bit. This is not optimal but
1471- // should be correct for now, consider computing correct bit depth
1472- // later.
1473- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth () > 32 ? 48 : 32 ;
1474-
14751473 FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
14761474 if (failed (maybeIZp)) {
14771475 (void )rewriter.notifyMatchFailure (
14781476 op, " input zero point cannot be statically determined" );
14791477 return ;
14801478 }
14811479
1482- auto inputZp = createConstOpFromZpVal<int32_t >(
1483- op, *maybeIZp, nestedBuilder.getIntegerType (inBitwidth),
1480+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth ();
1481+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32 ;
1482+ auto inputZp = createConstOpFromSExtZp (
1483+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned (), loc,
14841484 nestedBuilder);
14851485
14861486 FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14901490 return ;
14911491 };
14921492
1493- // pre-process OutputZP as it can be unsigned
1494- auto outBitwidth = outputTy.getElementType ().getIntOrFloatBitWidth ();
1495- APInt OZp (outBitwidth, !op.getOutputUnsigned ());
1496- OZp = static_cast <int64_t >(*maybeOZp);
1497- *maybeOZp = op.getOutputUnsigned ()
1498- ? static_cast <int64_t >(OZp.getZExtValue ())
1499- : OZp.getSExtValue ();
1500-
1501- auto outputZp = createConstOpFromZpVal<int32_t >(
1502- op, *maybeOZp, nestedBuilder.getI32Type (), nestedBuilder);
1493+ IntegerType outIntType =
1494+ cast<IntegerType>(blockArgs.back ().getType ());
1495+ unsigned outBitWidth = outIntType.getWidth ();
1496+ auto outputZp = createConstOpFromSExtZp (
1497+ *maybeOZp, outBitWidth, /* attrBitwidth=*/ 32 ,
1498+ !op.getOutputUnsigned (), loc, nestedBuilder);
15031499
15041500 Value multiplier = multiplierConstant ? multiplierConstant
15051501 : blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15271523 nestedBuilder.create <arith::AddIOp>(nestedLoc, value, outputZp);
15281524
15291525 // Saturate to the output size.
1530- IntegerType outIntType =
1531- cast<IntegerType>(blockArgs.back ().getType ());
1532- unsigned outBitWidth = outIntType.getWidth ();
1533-
15341526 int32_t intMin = APInt::getSignedMinValue (outBitWidth).getSExtValue ();
15351527 int32_t intMax = APInt::getSignedMaxValue (outBitWidth).getSExtValue ();
15361528
0 commit comments