@@ -672,14 +672,15 @@ std::optional<Value> floorIntDiv(PatternRewriter &rewriter, Operation *op,
672672 auto boolType =
673673 RankedTensorType::get (outType.getShape (), rewriter.getIntegerType (1 ));
674674
675- auto lhsMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type, lhs, rhs,
676- /* shift=*/ 0 );
675+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
676+ auto lhsMulRhs =
677+ rewriter.create <tosa::MulOp>(op->getLoc (), i32Type, lhs, rhs, mulShift);
677678
678679 auto lhsRhsDifferentSign =
679680 rewriter.create <tosa::GreaterOp>(op->getLoc (), boolType, zero, lhsMulRhs);
680681
681682 auto truncMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type,
682- intDivOp, rhs, /* shift= */ 0 );
683+ intDivOp, rhs, mulShift );
683684
684685 auto truncMulRhsEqualLhs =
685686 rewriter.create <tosa::EqualOp>(op->getLoc (), boolType, truncMulRhs, lhs);
@@ -918,9 +919,10 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
918919 op->getLoc (),
919920 RankedTensorType::get (selfTy.getShape (), rewriter.getIntegerType (1 )),
920921 self, zero);
922+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
921923 auto mulTensor = rewriter.create <tosa::MulOp>(
922924 op->getLoc (), getTypeConverter ()->convertType (op.getType ()), self,
923- alphaTensor, /* shift= */ 0 );
925+ alphaTensor, mulShift );
924926
925927 rewriter.replaceOpWithNewOp <tosa::SelectOp>(
926928 op, getTypeConverter ()->convertType (op.getType ()), cond, self, mulTensor);
@@ -2348,8 +2350,9 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
23482350 return rewriter.notifyMatchFailure (
23492351 op, " Failed to equalize ranks among operands and result" );
23502352
2353+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
23512354 auto multTensor = rewriter.create <tosa::MulOp>(op->getLoc (), resultTy, self,
2352- alphaTensor, /* shift= */ 0 );
2355+ alphaTensor, mulShift );
23532356
23542357 rewriter.replaceOpWithNewOp <tosa::SubOp>(op, resultTy, otherTensor,
23552358 multTensor);
@@ -2761,12 +2764,13 @@ std::optional<Value> computeBatchNorm(Operation *op,
27612764 auto op3RsqrtOp2 = rewriter.create <tosa::RsqrtOp>(
27622765 op->getLoc (), variance.getType (), op2AddVarEpsilon.getResult ());
27632766
2764- auto op4MulOp1Op3 = rewriter.create <tosa::MulOp>(op->getLoc (), outType,
2765- op1SubInputMean.getResult (),
2766- op3RsqrtOp2.getResult (), 0 );
2767+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
2768+ auto op4MulOp1Op3 = rewriter.create <tosa::MulOp>(
2769+ op->getLoc (), outType, op1SubInputMean.getResult (),
2770+ op3RsqrtOp2.getResult (), mulShift);
27672771
27682772 auto op5MulOp4Scale = rewriter.create <tosa::MulOp>(
2769- op->getLoc (), outType, op4MulOp1Op3.getResult (), weight, 0 );
2773+ op->getLoc (), outType, op4MulOp1Op3.getResult (), weight, mulShift );
27702774
27712775 return rewriter
27722776 .create <tosa::AddOp>(op->getLoc (), outType, op5MulOp4Scale.getResult (),
@@ -2986,22 +2990,24 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
29862990 auto bcastOutType =
29872991 RankedTensorType::get (makeShapeLLVMCompatible (bcastOutShape), elemTy);
29882992
2993+ auto mulShift =
2994+ tosa::getTosaMulShiftConstTensor (rewriter, op.getOperation (), 0 );
29892995 // Compute mean.
29902996 Value sum =
29912997 computeSumAndReshape (input, inputType, bcastOutType, bcastOutShape);
29922998 Value meanVal = rewriter.create <tosa::MulOp>(op.getLoc (), bcastOutType, sum,
2993- elemCntRcp, /* shift= */ 0 );
2999+ elemCntRcp, mulShift );
29943000
29953001 // Compute variance.
29963002 Value squareSumSub =
29973003 rewriter.create <tosa::SubOp>(op.getLoc (), inputType, input, meanVal);
2998- Value squareSum = rewriter.create <tosa::MulOp>(op. getLoc (), inputType,
2999- squareSumSub, squareSumSub, 0 );
3004+ Value squareSum = rewriter.create <tosa::MulOp>(
3005+ op. getLoc (), inputType, squareSumSub, squareSumSub, mulShift );
30003006
30013007 Value squareSumReduced =
30023008 computeSumAndReshape (squareSum, inputType, bcastOutType, bcastOutShape);
30033009 Value varianceVal = rewriter.create <tosa::MulOp>(
3004- op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp, /* shift= */ 0 );
3010+ op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp, mulShift );
30053011
30063012 // Reshape weight and bias.
30073013 SmallVector<int64_t > weightAndBiasBcastShape;
@@ -3259,8 +3265,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
32593265 rewriter.create <tosa::ReciprocalOp>(op.getLoc (), ln2Op.getType (), ln2Op);
32603266
32613267 auto logOp = rewriter.create <tosa::LogOp>(op.getLoc (), outType, self);
3262- rewriter. replaceOpWithNewOp < tosa::MulOp>(op, outType, logOp, rcpOp,
3263- /* shift= */ 0 );
3268+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
3269+ rewriter. replaceOpWithNewOp <tosa::MulOp>(op, outType, logOp, rcpOp, mulShift );
32643270
32653271 return success ();
32663272}
@@ -3497,26 +3503,27 @@ approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
34973503 mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a4).failed ())
34983504 return std::nullopt ;
34993505
3500- auto a1X = rewriter.create <tosa::MulOp>(loc, outType, a1, absX, /* shift=*/ 0 );
3506+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
3507+ auto a1X = rewriter.create <tosa::MulOp>(loc, outType, a1, absX, mulShift);
35013508 auto sum = rewriter.create <tosa::AddOp>(loc, outType, a1X, one);
35023509
3503- auto x2 = rewriter.create <tosa::MulOp>(loc, outType, absX, absX, /* shift= */ 0 );
3504- auto a2X = rewriter.create <tosa::MulOp>(loc, outType, a2, x2, /* shift= */ 0 );
3510+ auto x2 = rewriter.create <tosa::MulOp>(loc, outType, absX, absX, mulShift );
3511+ auto a2X = rewriter.create <tosa::MulOp>(loc, outType, a2, x2, mulShift );
35053512 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a2X);
35063513
3507- auto x3 = rewriter.create <tosa::MulOp>(loc, outType, x2, absX, /* shift= */ 0 );
3508- auto a3X = rewriter.create <tosa::MulOp>(loc, outType, a3, x3, /* shift= */ 0 );
3514+ auto x3 = rewriter.create <tosa::MulOp>(loc, outType, x2, absX, mulShift );
3515+ auto a3X = rewriter.create <tosa::MulOp>(loc, outType, a3, x3, mulShift );
35093516 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a3X);
35103517
3511- auto x4 = rewriter.create <tosa::MulOp>(loc, outType, x3, absX, /* shift= */ 0 );
3512- auto a4X = rewriter.create <tosa::MulOp>(loc, outType, a4, x4, /* shift= */ 0 );
3518+ auto x4 = rewriter.create <tosa::MulOp>(loc, outType, x3, absX, mulShift );
3519+ auto a4X = rewriter.create <tosa::MulOp>(loc, outType, a4, x4, mulShift );
35133520 sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a4X);
35143521
35153522 auto rcprl = rewriter.create <tosa::ReciprocalOp>(loc, outType, sum);
35163523 auto rcprl2 =
3517- rewriter.create <tosa::MulOp>(loc, outType, rcprl, rcprl, /* shift= */ 0 );
3524+ rewriter.create <tosa::MulOp>(loc, outType, rcprl, rcprl, mulShift );
35183525 auto rcprl4 =
3519- rewriter.create <tosa::MulOp>(loc, outType, rcprl2, rcprl2, /* shift= */ 0 );
3526+ rewriter.create <tosa::MulOp>(loc, outType, rcprl2, rcprl2, mulShift );
35203527 auto erf = rewriter.create <tosa::SubOp>(loc, outType, one, rcprl4);
35213528
35223529 // Deal with negative x.
@@ -3553,13 +3560,14 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
35533560 auto mean = zero;
35543561 Value xMinusMean = rewriter.create <tosa::SubOp>(loc, outType, x, mean);
35553562
3556- Value erfArg = rewriter.create <tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
3557- /* shift=*/ 0 );
3563+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
3564+ Value erfArg =
3565+ rewriter.create <tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, mulShift);
35583566 Value erf = approximateErfOp (rewriter, op, erfArg, dtype).value ();
35593567 Value erfPlus1 = rewriter.create <tosa::AddOp>(loc, outType, one, erf);
35603568
3561- Value normalCdf = rewriter. create <tosa::MulOp>(loc, outType, oneHalf,
3562- erfPlus1, /* shift= */ 0 );
3569+ Value normalCdf =
3570+ rewriter. create <tosa::MulOp>(loc, outType, oneHalf, erfPlus1, mulShift );
35633571 return normalCdf;
35643572}
35653573
@@ -3599,8 +3607,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
35993607 op->getLoc (),
36003608 cast<RankedTensorType>(cdf.getType ()).cloneWith ({}, selfElemTy), cdf);
36013609
3610+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
36023611 rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, cdf,
3603- /* shift= */ 0 );
3612+ mulShift );
36043613 } else if (approximate.compare (" tanh" ) == 0 ) {
36053614 // "tanh" approximate
36063615 // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3644,8 +3653,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36443653 .value ();
36453654
36463655 // 0.5 * x
3656+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
36473657 auto halfInput = rewriter.create <tosa::MulOp>(op->getLoc (), resultType,
3648- half, self, /* shift= */ 0 );
3658+ half, self, mulShift );
36493659
36503660 // sqrt(2/pi)
36513661 auto sqrtTwoOverPi =
@@ -3658,7 +3668,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36583668 // 0.044715 * x^3
36593669 auto inputPowThreeMul =
36603670 rewriter.create <tosa::MulOp>(op->getLoc (), resultType, magicNumber,
3661- inputPowThree.getResult (), /* shift= */ 0 );
3671+ inputPowThree.getResult (), mulShift );
36623672
36633673 // x + 0.044715 * x^3
36643674 auto inputPowThreeMulAdd = rewriter.create <tosa::AddOp>(
@@ -3667,7 +3677,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36673677 // sqrt(2/pi) * (x + 0.044715 * x^3)
36683678 auto sqrtTwoOverPiMul = rewriter.create <tosa::MulOp>(
36693679 op->getLoc (), resultType, sqrtTwoOverPi.getResult (),
3670- inputPowThreeMulAdd.getResult (), /* shift= */ 0 );
3680+ inputPowThreeMulAdd.getResult (), mulShift );
36713681
36723682 // tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
36733683 auto tanh = rewriter.create <tosa::TanhOp>(op->getLoc (), resultType,
@@ -3678,8 +3688,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36783688 tanh.getResult ());
36793689
36803690 rewriter.replaceOpWithNewOp <tosa::MulOp>(
3681- op, resultType, halfInput.getResult (), tanhAdd.getResult (),
3682- /* shift=*/ 0 );
3691+ op, resultType, halfInput.getResult (), tanhAdd.getResult (), mulShift);
36833692 } else {
36843693 return rewriter.notifyMatchFailure (op,
36853694 " Unsupported approximation algorithm" );
@@ -3732,23 +3741,23 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
37323741 return rewriter.notifyMatchFailure (
37333742 op, " Failed to equalize ranks among operands and result" );
37343743
3744+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
37353745 Value inputSquared =
3736- rewriter.create <tosa::MulOp>(loc, selfType, self, self, /* shift= */ 0 );
3746+ rewriter.create <tosa::MulOp>(loc, selfType, self, self, mulShift );
37373747 Value negHalfInputSquared = rewriter.create <tosa::MulOp>(
3738- loc, selfType, inputSquared, negOneHalf, /* shift= */ 0 );
3748+ loc, selfType, inputSquared, negOneHalf, mulShift );
37393749 Value dinput =
37403750 rewriter.create <tosa::ExpOp>(loc, selfType, negHalfInputSquared);
37413751 Value cdf = buildUnitNormalCdf (rewriter, op, self, selfElemTy).value ();
37423752 Value dinputInput =
3743- rewriter.create <tosa::MulOp>(loc, selfType, dinput, self, /* shift= */ 0 );
3753+ rewriter.create <tosa::MulOp>(loc, selfType, dinput, self, mulShift );
37443754 Value dinputInputAlpha = rewriter.create <tosa::MulOp>(
3745- loc, selfType, dinputInput, kAlphaHalf , /* shift= */ 0 );
3755+ loc, selfType, dinputInput, kAlphaHalf , mulShift );
37463756 Value cdfExt =
37473757 rewriter.create <tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
37483758 rewriter.replaceOpWithNewOp <tosa::MulOp>(
37493759 op, getTypeConverter ()->convertType (op.getType ()),
3750- adaptor.getGradOutput (), cdfExt,
3751- /* shift=*/ 0 );
3760+ adaptor.getGradOutput (), cdfExt, mulShift);
37523761
37533762 return success ();
37543763}
@@ -5232,8 +5241,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
52325241 rewriter.create <tosa::AbsOp>(op->getLoc (), selfType, rhsSubOp);
52335242
52345243 auto lhsAbsOp = rewriter.create <tosa::AbsOp>(op->getLoc (), otherType, other);
5244+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
52355245 auto mulOp = rewriter.create <tosa::MulOp>(op->getLoc (), otherType,
5236- rtolConstOp, lhsAbsOp, /* shift= */ 0 );
5246+ rtolConstOp, lhsAbsOp, mulShift );
52375247 auto addOp =
52385248 rewriter.create <tosa::AddOp>(op->getLoc (), otherType, atolConstOp, mulOp);
52395249
@@ -5778,8 +5788,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
57785788 if (isa<mlir::FloatType>(outElemTy)) {
57795789 auto otherTensorReciprocal = rewriter.create <tosa::ReciprocalOp>(
57805790 op.getLoc (), otherTensor.getType (), otherTensor);
5791+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
57815792 divTensor = rewriter.create <tosa::MulOp>(
5782- op.getLoc (), outType, self, otherTensorReciprocal, /* shift= */ 0 );
5793+ op.getLoc (), outType, self, otherTensorReciprocal, mulShift );
57835794 divTensor =
57845795 rewriter.create <tosa::FloorOp>(op.getLoc (), outType, divTensor);
57855796 } else {
@@ -5804,9 +5815,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
58045815 }
58055816 }
58065817
5807- auto mulTensor = rewriter. create < tosa::MulOp>(op. getLoc (), outType,
5808- otherTensor, divTensor,
5809- /* shift= */ 0 );
5818+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
5819+ auto mulTensor = rewriter. create <tosa::MulOp>(
5820+ op. getLoc (), outType, otherTensor, divTensor, mulShift );
58105821 rewriter.replaceOpWithNewOp <tosa::SubOp>(op, outType, self, mulTensor);
58115822
58125823 return success ();
@@ -7010,8 +7021,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
70107021 return rewriter.notifyMatchFailure (
70117022 op, " Failed to equalize ranks among operands and result" );
70127023
7024+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
70137025 rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, trilMask,
7014- /* shift= */ 0 );
7026+ mulShift );
70157027
70167028 return success ();
70177029}
@@ -7106,15 +7118,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
71067118
71077119 auto ceilInput = rewriter.create <tosa::CeilOp>(op->getLoc (), resultTy, self);
71087120
7121+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
71097122 auto floorInputDivByTwo = rewriter.create <tosa::MulOp>(
7110- op->getLoc (), resultTy, floorInput.getResult (), oneHalf, /* shift= */ 0 );
7123+ op->getLoc (), resultTy, floorInput.getResult (), oneHalf, mulShift );
71117124
71127125 auto floorDivResult = rewriter.create <tosa::FloorOp>(
71137126 op->getLoc (), resultTy, floorInputDivByTwo.getResult ());
71147127
71157128 // (floor(input) // 2) * 2
71167129 auto evenComparison = rewriter.create <tosa::MulOp>(
7117- op->getLoc (), resultTy, floorDivResult.getResult (), two, /* shift= */ 0 );
7130+ op->getLoc (), resultTy, floorDivResult.getResult (), two, mulShift );
71187131
71197132 // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
71207133 auto floorInputEven = rewriter.create <tosa::EqualOp>(
@@ -7296,9 +7309,10 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
72967309 return rewriter.notifyMatchFailure (
72977310 op, " Failed to equalize ranks among operands and result" );
72987311
7299- Value diagonalTensor = rewriter.create <tosa::MulOp>(
7300- op->getLoc (), transposedInputType, selfTransposed, diagonalMask,
7301- /* shift=*/ 0 );
7312+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
7313+ Value diagonalTensor =
7314+ rewriter.create <tosa::MulOp>(op->getLoc (), transposedInputType,
7315+ selfTransposed, diagonalMask, mulShift);
73027316
73037317 auto resultShape = makeShapeTorchCompatible (resultType.getShape ());
73047318 auto targetReduceDim = resultShape[resultType.getRank () - 1 ];
@@ -8587,9 +8601,9 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85878601 auto oneMinusZiReciprocal = rewriter.create <tosa::ReciprocalOp>(
85888602 op->getLoc (), resultType, oneMinusZi.getResult ());
85898603
8590- auto mulOp = rewriter. create < tosa::MulOp>(op-> getLoc (), resultType, zi,
8591- oneMinusZiReciprocal. getResult (),
8592- /* shift= */ 0 );
8604+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
8605+ auto mulOp = rewriter. create <tosa::MulOp>(
8606+ op-> getLoc (), resultType, zi, oneMinusZiReciprocal. getResult (), mulShift );
85938607
85948608 auto result =
85958609 rewriter.create <tosa::LogOp>(op->getLoc (), resultType, mulOp.getResult ());
@@ -8687,9 +8701,10 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
86878701 auto reciprocalOp = rewriter.create <tosa::ReciprocalOp>(
86888702 op->getLoc (), constTenType, logOfTen.getResult ());
86898703
8704+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
86908705 auto result = rewriter.create <tosa::MulOp>(
86918706 op->getLoc (), resultType, logOfSelf.getResult (), reciprocalOp.getResult (),
8692- /* shift= */ 0 );
8707+ mulShift );
86938708
86948709 rewriter.replaceOp (op, {result.getResult ()});
86958710
@@ -8772,9 +8787,10 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
87728787 auto reciprocalOp =
87738788 rewriter.create <tosa::ReciprocalOp>(op->getLoc (), resultType, cosOp);
87748789
8775- auto result = rewriter.create <tosa::MulOp>(
8776- op->getLoc (), resultType, sinOp.getResult (), reciprocalOp.getResult (),
8777- /* shift=*/ 0 );
8790+ auto mulShift = tosa::getTosaMulShiftConstTensor (rewriter, op, 0 );
8791+ auto result =
8792+ rewriter.create <tosa::MulOp>(op->getLoc (), resultType, sinOp.getResult (),
8793+ reciprocalOp.getResult (), mulShift);
87788794
87798795 rewriter.replaceOp (op, {result.getResult ()});
87808796
0 commit comments