@@ -672,14 +672,15 @@ std::optional<Value> floorIntDiv(PatternRewriter &rewriter, Operation *op,
672672 auto boolType =
673673 RankedTensorType::get (outType.getShape (), rewriter.getIntegerType (1 ));
674674
675- auto lhsMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type, lhs, rhs,
676- /* shift=*/ 0 );
675+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
676+ auto lhsMulRhs =
677+ rewriter.create <tosa::MulOp>(op->getLoc (), i32Type, lhs, rhs, mulShift);
677678
678679 auto lhsRhsDifferentSign =
679680 rewriter.create <tosa::GreaterOp>(op->getLoc (), boolType, zero, lhsMulRhs);
680681
681682 auto truncMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type,
682- intDivOp, rhs, /* shift= */ 0 );
683+ intDivOp, rhs, mulShift );
683684
684685 auto truncMulRhsEqualLhs =
685686 rewriter.create <tosa::EqualOp>(op->getLoc (), boolType, truncMulRhs, lhs);
@@ -918,9 +919,10 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
918919 op->getLoc (),
919920 RankedTensorType::get (selfTy.getShape (), rewriter.getIntegerType (1 )),
920921 self, zero);
922+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
921923 auto mulTensor = rewriter.create <tosa::MulOp>(
922924 op->getLoc (), getTypeConverter ()->convertType (op.getType ()), self,
923- alphaTensor, /* shift= */ 0 );
925+ alphaTensor, mulShift );
924926
925927 rewriter.replaceOpWithNewOp <tosa::SelectOp>(
926928 op, getTypeConverter ()->convertType (op.getType ()), cond, self, mulTensor);
@@ -2348,8 +2350,10 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
23482350 return rewriter.notifyMatchFailure (
23492351 op, " Failed to equalize ranks among operands and result" );
23502352
2351- auto multTensor = rewriter.create <tosa::MulOp>(op->getLoc (), resultTy, self,
2352- alphaTensor, /* shift=*/ 0 );
2353+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
2354+ auto multTensor =
2355+ rewriter.create <tosa::MulOp>(op->getLoc (), resultTy, self, alphaTensor,
2356+ mulShift);
23532357
23542358 rewriter.replaceOpWithNewOp <tosa::SubOp>(op, resultTy, otherTensor,
23552359 multTensor);
@@ -2761,12 +2765,14 @@ std::optional<Value> computeBatchNorm(Operation *op,
27612765 auto op3RsqrtOp2 = rewriter.create <tosa::RsqrtOp>(
27622766 op->getLoc (), variance.getType (), op2AddVarEpsilon.getResult ());
27632767
2764- auto op4MulOp1Op3 = rewriter.create <tosa::MulOp>(op->getLoc (), outType,
2765- op1SubInputMean.getResult (),
2766- op3RsqrtOp2.getResult (), 0 );
2768+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
2769+ auto op4MulOp1Op3 =
2770+ rewriter.create <tosa::MulOp>(op->getLoc (), outType,
2771+ op1SubInputMean.getResult (),
2772+ op3RsqrtOp2.getResult (), mulShift);
27672773
27682774 auto op5MulOp4Scale = rewriter.create <tosa::MulOp>(
2769- op->getLoc (), outType, op4MulOp1Op3.getResult (), weight, 0 );
2775+ op->getLoc (), outType, op4MulOp1Op3.getResult (), weight, mulShift );
27702776
27712777 return rewriter
27722778 .create <tosa::AddOp>(op->getLoc (), outType, op5MulOp4Scale.getResult (),
@@ -2986,22 +2992,24 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
29862992 auto bcastOutType =
29872993 RankedTensorType::get (makeShapeLLVMCompatible (bcastOutShape), elemTy);
29882994
2995+ auto mulShift =
2996+ tosa::getTosaMulShiftConstTensor (rewriter, op.getOperation (), 0 );
29892997 // Compute mean.
29902998 Value sum =
29912999 computeSumAndReshape (input, inputType, bcastOutType, bcastOutShape);
29923000 Value meanVal = rewriter.create <tosa::MulOp>(op.getLoc (), bcastOutType, sum,
2993- elemCntRcp, /* shift= */ 0 );
3001+ elemCntRcp, mulShift );
29943002
29953003 // Compute variance.
29963004 Value squareSumSub =
29973005 rewriter.create <tosa::SubOp>(op.getLoc (), inputType, input, meanVal);
2998- Value squareSum = rewriter.create <tosa::MulOp>(op. getLoc (), inputType,
2999- squareSumSub, squareSumSub, 0 );
3006+ Value squareSum = rewriter.create <tosa::MulOp>(
3007+ op. getLoc (), inputType, squareSumSub, squareSumSub, mulShift );
30003008
30013009 Value squareSumReduced =
30023010 computeSumAndReshape (squareSum, inputType, bcastOutType, bcastOutShape);
30033011 Value varianceVal = rewriter.create <tosa::MulOp>(
3004- op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp, /* shift= */ 0 );
3012+ op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp, mulShift );
30053013
30063014 // Reshape weight and bias.
30073015 SmallVector<int64_t > weightAndBiasBcastShape;
@@ -3259,8 +3267,9 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
32593267 rewriter.create <tosa::ReciprocalOp>(op.getLoc (), ln2Op.getType (), ln2Op);
32603268
32613269 auto logOp = rewriter.create <tosa::LogOp>(op.getLoc (), outType, self);
3270+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
32623271 rewriter.replaceOpWithNewOp <tosa::MulOp>(op, outType, logOp, rcpOp,
3263- /* shift= */ 0 );
3272+ mulShift );
32643273
32653274 return success ();
32663275}
@@ -3497,26 +3506,31 @@ approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
34973506 mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a4).failed ())
34983507 return std::nullopt ;
34993508
3500- auto a1X = rewriter.create <tosa::MulOp>(loc, outType, a1, absX, /* shift=*/ 0 );
3509+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
3510+ auto a1X =
3511+ rewriter.create <tosa::MulOp>(loc, outType, a1, absX, mulShift);
35013512 auto sum = rewriter.create <tosa::AddOp>(loc, outType, a1X, one);
35023513
3503- auto x2 = rewriter.create <tosa::MulOp>(loc, outType, absX, absX, /* shift=*/ 0 );
3504- auto a2X = rewriter.create <tosa::MulOp>(loc, outType, a2, x2, /* shift=*/ 0 );
3514+ auto x2 =
3515+ rewriter.create <tosa::MulOp>(loc, outType, absX, absX, mulShift);
3516+ auto a2X = rewriter.create <tosa::MulOp>(loc, outType, a2, x2, mulShift);
35053517 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a2X);
35063518
3507- auto x3 = rewriter.create <tosa::MulOp>(loc, outType, x2, absX, /* shift=*/ 0 );
3508- auto a3X = rewriter.create <tosa::MulOp>(loc, outType, a3, x3, /* shift=*/ 0 );
3519+ auto x3 =
3520+ rewriter.create <tosa::MulOp>(loc, outType, x2, absX, mulShift);
3521+ auto a3X = rewriter.create <tosa::MulOp>(loc, outType, a3, x3, mulShift);
35093522 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a3X);
35103523
3511- auto x4 = rewriter.create <tosa::MulOp>(loc, outType, x3, absX, /* shift=*/ 0 );
3512- auto a4X = rewriter.create <tosa::MulOp>(loc, outType, a4, x4, /* shift=*/ 0 );
3524+ auto x4 =
3525+ rewriter.create <tosa::MulOp>(loc, outType, x3, absX, mulShift);
3526+ auto a4X = rewriter.create <tosa::MulOp>(loc, outType, a4, x4, mulShift);
35133527 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a4X);
35143528
35153529 auto rcprl = rewriter.create <tosa::ReciprocalOp>(loc, outType, sum);
35163530 auto rcprl2 =
3517- rewriter.create <tosa::MulOp>(loc, outType, rcprl, rcprl, /* shift= */ 0 );
3531+ rewriter.create <tosa::MulOp>(loc, outType, rcprl, rcprl, mulShift );
35183532 auto rcprl4 =
3519- rewriter.create <tosa::MulOp>(loc, outType, rcprl2, rcprl2, /* shift= */ 0 );
3533+ rewriter.create <tosa::MulOp>(loc, outType, rcprl2, rcprl2, mulShift );
35203534 auto erf = rewriter.create <tosa::SubOp>(loc, outType, one, rcprl4);
35213535
35223536 // Deal with negative x.
@@ -3553,13 +3567,14 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
35533567 auto mean = zero;
35543568 Value xMinusMean = rewriter.create <tosa::SubOp>(loc, outType, x, mean);
35553569
3556- Value erfArg = rewriter.create <tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
3557- /* shift=*/ 0 );
3570+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
3571+ Value erfArg =
3572+ rewriter.create <tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, mulShift);
35583573 Value erf = approximateErfOp (rewriter, op, erfArg, dtype).value ();
35593574 Value erfPlus1 = rewriter.create <tosa::AddOp>(loc, outType, one, erf);
35603575
3561- Value normalCdf = rewriter. create <tosa::MulOp>(loc, outType, oneHalf,
3562- erfPlus1, /* shift= */ 0 );
3576+ Value normalCdf =
3577+ rewriter. create <tosa::MulOp>(loc, outType, oneHalf, erfPlus1, mulShift );
35633578 return normalCdf;
35643579}
35653580
@@ -3599,8 +3614,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
35993614 op->getLoc (),
36003615 cast<RankedTensorType>(cdf.getType ()).cloneWith ({}, selfElemTy), cdf);
36013616
3617+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
36023618 rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, cdf,
3603- /* shift= */ 0 );
3619+ mulShift );
36043620 } else if (approximate.compare (" tanh" ) == 0 ) {
36053621 // "tanh" approximate
36063622 // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3644,8 +3660,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36443660 .value ();
36453661
36463662 // 0.5 * x
3663+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
36473664 auto halfInput = rewriter.create <tosa::MulOp>(op->getLoc (), resultType,
3648- half, self, /* shift= */ 0 );
3665+ half, self, mulShift );
36493666
36503667 // sqrt(2/pi)
36513668 auto sqrtTwoOverPi =
@@ -3658,7 +3675,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36583675 // 0.044715 * x^3
36593676 auto inputPowThreeMul =
36603677 rewriter.create <tosa::MulOp>(op->getLoc (), resultType, magicNumber,
3661- inputPowThree.getResult (), /* shift= */ 0 );
3678+ inputPowThree.getResult (), mulShift );
36623679
36633680 // x + 0.044715 * x^3
36643681 auto inputPowThreeMulAdd = rewriter.create <tosa::AddOp>(
@@ -3667,7 +3684,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36673684 // sqrt(2/pi) * (x + 0.044715 * x^3)
36683685 auto sqrtTwoOverPiMul = rewriter.create <tosa::MulOp>(
36693686 op->getLoc (), resultType, sqrtTwoOverPi.getResult (),
3670- inputPowThreeMulAdd.getResult (), /* shift= */ 0 );
3687+ inputPowThreeMulAdd.getResult (), mulShift );
36713688
36723689 // tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
36733690 auto tanh = rewriter.create <tosa::TanhOp>(op->getLoc (), resultType,
@@ -3679,7 +3696,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36793696
36803697 rewriter.replaceOpWithNewOp <tosa::MulOp>(
36813698 op, resultType, halfInput.getResult (), tanhAdd.getResult (),
3682- /* shift= */ 0 );
3699+ mulShift );
36833700 } else {
36843701 return rewriter.notifyMatchFailure (op,
36853702 " Unsupported approximation algorithm" );
@@ -3732,23 +3749,23 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
37323749 return rewriter.notifyMatchFailure (
37333750 op, " Failed to equalize ranks among operands and result" );
37343751
3752+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
37353753 Value inputSquared =
3736- rewriter.create <tosa::MulOp>(loc, selfType, self, self, /* shift= */ 0 );
3754+ rewriter.create <tosa::MulOp>(loc, selfType, self, self, mulShift );
37373755 Value negHalfInputSquared = rewriter.create <tosa::MulOp>(
3738- loc, selfType, inputSquared, negOneHalf, /* shift= */ 0 );
3756+ loc, selfType, inputSquared, negOneHalf, mulShift );
37393757 Value dinput =
37403758 rewriter.create <tosa::ExpOp>(loc, selfType, negHalfInputSquared);
37413759 Value cdf = buildUnitNormalCdf (rewriter, op, self, selfElemTy).value ();
37423760 Value dinputInput =
3743- rewriter.create <tosa::MulOp>(loc, selfType, dinput, self, /* shift= */ 0 );
3761+ rewriter.create <tosa::MulOp>(loc, selfType, dinput, self, mulShift );
37443762 Value dinputInputAlpha = rewriter.create <tosa::MulOp>(
3745- loc, selfType, dinputInput, kAlphaHalf , /* shift= */ 0 );
3763+ loc, selfType, dinputInput, kAlphaHalf , mulShift );
37463764 Value cdfExt =
37473765 rewriter.create <tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
37483766 rewriter.replaceOpWithNewOp <tosa::MulOp>(
37493767 op, getTypeConverter ()->convertType (op.getType ()),
3750- adaptor.getGradOutput (), cdfExt,
3751- /* shift=*/ 0 );
3768+ adaptor.getGradOutput (), cdfExt, mulShift);
37523769
37533770 return success ();
37543771}
@@ -5232,8 +5249,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
52325249 rewriter.create <tosa::AbsOp>(op->getLoc (), selfType, rhsSubOp);
52335250
52345251 auto lhsAbsOp = rewriter.create <tosa::AbsOp>(op->getLoc (), otherType, other);
5252+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
52355253 auto mulOp = rewriter.create <tosa::MulOp>(op->getLoc (), otherType,
5236- rtolConstOp, lhsAbsOp, /* shift= */ 0 );
5254+ rtolConstOp, lhsAbsOp, mulShift );
52375255 auto addOp =
52385256 rewriter.create <tosa::AddOp>(op->getLoc (), otherType, atolConstOp, mulOp);
52395257
@@ -5778,8 +5796,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
57785796 if (isa<mlir::FloatType>(outElemTy)) {
57795797 auto otherTensorReciprocal = rewriter.create <tosa::ReciprocalOp>(
57805798 op.getLoc (), otherTensor.getType (), otherTensor);
5799+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
57815800 divTensor = rewriter.create <tosa::MulOp>(
5782- op.getLoc (), outType, self, otherTensorReciprocal, /* shift= */ 0 );
5801+ op.getLoc (), outType, self, otherTensorReciprocal, mulShift );
57835802 divTensor =
57845803 rewriter.create <tosa::FloorOp>(op.getLoc (), outType, divTensor);
57855804 } else {
@@ -5804,9 +5823,10 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
58045823 }
58055824 }
58065825
5807- auto mulTensor = rewriter.create <tosa::MulOp>(op.getLoc (), outType,
5808- otherTensor, divTensor,
5809- /* shift=*/ 0 );
5826+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
5827+ auto mulTensor =
5828+ rewriter.create <tosa::MulOp>(op.getLoc (), outType, otherTensor,
5829+ divTensor, mulShift);
58105830 rewriter.replaceOpWithNewOp <tosa::SubOp>(op, outType, self, mulTensor);
58115831
58125832 return success ();
@@ -7010,8 +7030,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
70107030 return rewriter.notifyMatchFailure (
70117031 op, " Failed to equalize ranks among operands and result" );
70127032
7033+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
70137034 rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, trilMask,
7014- /* shift= */ 0 );
7035+ mulShift );
70157036
70167037 return success ();
70177038}
@@ -7106,15 +7127,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
71067127
71077128 auto ceilInput = rewriter.create <tosa::CeilOp>(op->getLoc (), resultTy, self);
71087129
7130+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
71097131 auto floorInputDivByTwo = rewriter.create <tosa::MulOp>(
7110- op->getLoc (), resultTy, floorInput.getResult (), oneHalf, /* shift= */ 0 );
7132+ op->getLoc (), resultTy, floorInput.getResult (), oneHalf, mulShift );
71117133
71127134 auto floorDivResult = rewriter.create <tosa::FloorOp>(
71137135 op->getLoc (), resultTy, floorInputDivByTwo.getResult ());
71147136
71157137 // (floor(input) // 2) * 2
71167138 auto evenComparison = rewriter.create <tosa::MulOp>(
7117- op->getLoc (), resultTy, floorDivResult.getResult (), two, /* shift= */ 0 );
7139+ op->getLoc (), resultTy, floorDivResult.getResult (), two, mulShift );
71187140
71197141 // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
71207142 auto floorInputEven = rewriter.create <tosa::EqualOp>(
@@ -7296,9 +7318,10 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
72967318 return rewriter.notifyMatchFailure (
72977319 op, " Failed to equalize ranks among operands and result" );
72987320
7321+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
72997322 Value diagonalTensor = rewriter.create <tosa::MulOp>(
73007323 op->getLoc (), transposedInputType, selfTransposed, diagonalMask,
7301- /* shift= */ 0 );
7324+ mulShift );
73027325
73037326 auto resultShape = makeShapeTorchCompatible (resultType.getShape ());
73047327 auto targetReduceDim = resultShape[resultType.getRank () - 1 ];
@@ -8587,9 +8610,10 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85878610 auto oneMinusZiReciprocal = rewriter.create <tosa::ReciprocalOp>(
85888611 op->getLoc (), resultType, oneMinusZi.getResult ());
85898612
8613+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
85908614 auto mulOp = rewriter.create <tosa::MulOp>(op->getLoc (), resultType, zi,
85918615 oneMinusZiReciprocal.getResult (),
8592- /* shift= */ 0 );
8616+ mulShift );
85938617
85948618 auto result =
85958619 rewriter.create <tosa::LogOp>(op->getLoc (), resultType, mulOp.getResult ());
@@ -8687,9 +8711,11 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
86878711 auto reciprocalOp = rewriter.create <tosa::ReciprocalOp>(
86888712 op->getLoc (), constTenType, logOfTen.getResult ());
86898713
8690- auto result = rewriter.create <tosa::MulOp>(
8691- op->getLoc (), resultType, logOfSelf.getResult (), reciprocalOp.getResult (),
8692- /* shift=*/ 0 );
8714+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
8715+ auto result =
8716+ rewriter.create <tosa::MulOp>(op->getLoc (), resultType,
8717+ logOfSelf.getResult (),
8718+ reciprocalOp.getResult (), mulShift);
86938719
86948720 rewriter.replaceOp (op, {result.getResult ()});
86958721
@@ -8772,9 +8798,10 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
87728798 auto reciprocalOp =
87738799 rewriter.create <tosa::ReciprocalOp>(op->getLoc (), resultType, cosOp);
87748800
8801+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
87758802 auto result = rewriter.create <tosa::MulOp>(
87768803 op->getLoc (), resultType, sinOp.getResult (), reciprocalOp.getResult (),
8777- /* shift= */ 0 );
8804+ mulShift );
87788805
87798806 rewriter.replaceOp (op, {result.getResult ()});
87808807
0 commit comments