Skip to content

Commit 2191269

Browse files
committed
Adjust for LLVM bump_to_a58e774f
1 parent 89e85e1 commit 2191269

File tree

5 files changed

+105
-58
lines changed

5 files changed

+105
-58
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4949
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
5050
float val);
5151

52+
// Create an int8_t const tosa.mul shift tensor from an int
53+
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
54+
int32_t shift);
55+
5256
// Create a zero constant tensor of the desired type and shape.
5357
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
5458
Operation *op, Type type);

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -672,14 +672,15 @@ std::optional<Value> floorIntDiv(PatternRewriter &rewriter, Operation *op,
672672
auto boolType =
673673
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
674674

675-
auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
676-
/*shift=*/0);
675+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
676+
auto lhsMulRhs =
677+
rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs, mulShift);
677678

678679
auto lhsRhsDifferentSign =
679680
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
680681

681682
auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
682-
intDivOp, rhs, /*shift=*/0);
683+
intDivOp, rhs, mulShift);
683684

684685
auto truncMulRhsEqualLhs =
685686
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
@@ -918,9 +919,10 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
918919
op->getLoc(),
919920
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
920921
self, zero);
922+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
921923
auto mulTensor = rewriter.create<tosa::MulOp>(
922924
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
923-
alphaTensor, /*shift=*/0);
925+
alphaTensor, mulShift);
924926

925927
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
926928
op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor);
@@ -2348,8 +2350,10 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
23482350
return rewriter.notifyMatchFailure(
23492351
op, "Failed to equalize ranks among operands and result");
23502352

2351-
auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
2352-
alphaTensor, /*shift=*/0);
2353+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
2354+
auto multTensor =
2355+
rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self, alphaTensor,
2356+
mulShift);
23532357

23542358
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
23552359
multTensor);
@@ -2761,12 +2765,14 @@ std::optional<Value> computeBatchNorm(Operation *op,
27612765
auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
27622766
op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());
27632767

2764-
auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
2765-
op1SubInputMean.getResult(),
2766-
op3RsqrtOp2.getResult(), 0);
2768+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
2769+
auto op4MulOp1Op3 =
2770+
rewriter.create<tosa::MulOp>(op->getLoc(), outType,
2771+
op1SubInputMean.getResult(),
2772+
op3RsqrtOp2.getResult(), mulShift);
27672773

27682774
auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
2769-
op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);
2775+
op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, mulShift);
27702776

27712777
return rewriter
27722778
.create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
@@ -2986,22 +2992,24 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
29862992
auto bcastOutType =
29872993
RankedTensorType::get(makeShapeLLVMCompatible(bcastOutShape), elemTy);
29882994

2995+
auto mulShift =
2996+
tosa::getTosaMulShiftConstTensor(rewriter, op.getOperation(), 0);
29892997
// Compute mean.
29902998
Value sum =
29912999
computeSumAndReshape(input, inputType, bcastOutType, bcastOutShape);
29923000
Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
2993-
elemCntRcp, /*shift=*/0);
3001+
elemCntRcp, mulShift);
29943002

29953003
// Compute variance.
29963004
Value squareSumSub =
29973005
rewriter.create<tosa::SubOp>(op.getLoc(), inputType, input, meanVal);
2998-
Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
2999-
squareSumSub, squareSumSub, 0);
3006+
Value squareSum = rewriter.create<tosa::MulOp>(
3007+
op.getLoc(), inputType, squareSumSub, squareSumSub, mulShift);
30003008

30013009
Value squareSumReduced =
30023010
computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
30033011
Value varianceVal = rewriter.create<tosa::MulOp>(
3004-
op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);
3012+
op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, mulShift);
30053013

30063014
// Reshape weight and bias.
30073015
SmallVector<int64_t> weightAndBiasBcastShape;
@@ -3259,8 +3267,9 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
32593267
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
32603268

32613269
auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
3270+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
32623271
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
3263-
/*shift=*/0);
3272+
mulShift);
32643273

32653274
return success();
32663275
}
@@ -3497,26 +3506,31 @@ approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
34973506
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed())
34983507
return std::nullopt;
34993508

3500-
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
3509+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
3510+
auto a1X =
3511+
rewriter.create<tosa::MulOp>(loc, outType, a1, absX, mulShift);
35013512
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
35023513

3503-
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
3504-
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
3514+
auto x2 =
3515+
rewriter.create<tosa::MulOp>(loc, outType, absX, absX, mulShift);
3516+
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, mulShift);
35053517
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
35063518

3507-
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
3508-
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
3519+
auto x3 =
3520+
rewriter.create<tosa::MulOp>(loc, outType, x2, absX, mulShift);
3521+
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, mulShift);
35093522
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
35103523

3511-
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
3512-
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
3524+
auto x4 =
3525+
rewriter.create<tosa::MulOp>(loc, outType, x3, absX, mulShift);
3526+
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, mulShift);
35133527
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
35143528

35153529
auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
35163530
auto rcprl2 =
3517-
rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
3531+
rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, mulShift);
35183532
auto rcprl4 =
3519-
rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
3533+
rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, mulShift);
35203534
auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
35213535

35223536
// Deal with negative x.
@@ -3553,13 +3567,14 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
35533567
auto mean = zero;
35543568
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
35553569

3556-
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
3557-
/*shift=*/0);
3570+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
3571+
Value erfArg =
3572+
rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, mulShift);
35583573
Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value();
35593574
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
35603575

3561-
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
3562-
erfPlus1, /*shift=*/0);
3576+
Value normalCdf =
3577+
rewriter.create<tosa::MulOp>(loc, outType, oneHalf, erfPlus1, mulShift);
35633578
return normalCdf;
35643579
}
35653580

@@ -3599,8 +3614,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
35993614
op->getLoc(),
36003615
cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
36013616

3617+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
36023618
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, cdf,
3603-
/*shift=*/0);
3619+
mulShift);
36043620
} else if (approximate.compare("tanh") == 0) {
36053621
// "tanh" approximate
36063622
// GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3644,8 +3660,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36443660
.value();
36453661

36463662
// 0.5 * x
3663+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
36473664
auto halfInput = rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
3648-
half, self, /*shift=*/0);
3665+
half, self, mulShift);
36493666

36503667
// sqrt(2/pi)
36513668
auto sqrtTwoOverPi =
@@ -3658,7 +3675,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36583675
// 0.044715 * x^3
36593676
auto inputPowThreeMul =
36603677
rewriter.create<tosa::MulOp>(op->getLoc(), resultType, magicNumber,
3661-
inputPowThree.getResult(), /*shift=*/0);
3678+
inputPowThree.getResult(), mulShift);
36623679

36633680
// x + 0.044715 * x^3
36643681
auto inputPowThreeMulAdd = rewriter.create<tosa::AddOp>(
@@ -3667,7 +3684,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36673684
// sqrt(2/pi) * (x + 0.044715 * x^3)
36683685
auto sqrtTwoOverPiMul = rewriter.create<tosa::MulOp>(
36693686
op->getLoc(), resultType, sqrtTwoOverPi.getResult(),
3670-
inputPowThreeMulAdd.getResult(), /*shift=*/0);
3687+
inputPowThreeMulAdd.getResult(), mulShift);
36713688

36723689
// tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
36733690
auto tanh = rewriter.create<tosa::TanhOp>(op->getLoc(), resultType,
@@ -3679,7 +3696,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
36793696

36803697
rewriter.replaceOpWithNewOp<tosa::MulOp>(
36813698
op, resultType, halfInput.getResult(), tanhAdd.getResult(),
3682-
/*shift=*/0);
3699+
mulShift);
36833700
} else {
36843701
return rewriter.notifyMatchFailure(op,
36853702
"Unsupported approximation algorithm");
@@ -3732,23 +3749,23 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
37323749
return rewriter.notifyMatchFailure(
37333750
op, "Failed to equalize ranks among operands and result");
37343751

3752+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
37353753
Value inputSquared =
3736-
rewriter.create<tosa::MulOp>(loc, selfType, self, self, /*shift=*/0);
3754+
rewriter.create<tosa::MulOp>(loc, selfType, self, self, mulShift);
37373755
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
3738-
loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
3756+
loc, selfType, inputSquared, negOneHalf, mulShift);
37393757
Value dinput =
37403758
rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
37413759
Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value();
37423760
Value dinputInput =
3743-
rewriter.create<tosa::MulOp>(loc, selfType, dinput, self, /*shift=*/0);
3761+
rewriter.create<tosa::MulOp>(loc, selfType, dinput, self, mulShift);
37443762
Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
3745-
loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
3763+
loc, selfType, dinputInput, kAlphaHalf, mulShift);
37463764
Value cdfExt =
37473765
rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
37483766
rewriter.replaceOpWithNewOp<tosa::MulOp>(
37493767
op, getTypeConverter()->convertType(op.getType()),
3750-
adaptor.getGradOutput(), cdfExt,
3751-
/*shift=*/0);
3768+
adaptor.getGradOutput(), cdfExt, mulShift);
37523769

37533770
return success();
37543771
}
@@ -5232,8 +5249,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
52325249
rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);
52335250

52345251
auto lhsAbsOp = rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, other);
5252+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
52355253
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
5236-
rtolConstOp, lhsAbsOp, /*shift=*/0);
5254+
rtolConstOp, lhsAbsOp, mulShift);
52375255
auto addOp =
52385256
rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);
52395257

@@ -5778,8 +5796,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
57785796
if (isa<mlir::FloatType>(outElemTy)) {
57795797
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
57805798
op.getLoc(), otherTensor.getType(), otherTensor);
5799+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
57815800
divTensor = rewriter.create<tosa::MulOp>(
5782-
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
5801+
op.getLoc(), outType, self, otherTensorReciprocal, mulShift);
57835802
divTensor =
57845803
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
57855804
} else {
@@ -5804,9 +5823,10 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
58045823
}
58055824
}
58065825

5807-
auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
5808-
otherTensor, divTensor,
5809-
/*shift=*/0);
5826+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
5827+
auto mulTensor =
5828+
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor,
5829+
divTensor, mulShift);
58105830
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
58115831

58125832
return success();
@@ -7010,8 +7030,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
70107030
return rewriter.notifyMatchFailure(
70117031
op, "Failed to equalize ranks among operands and result");
70127032

7033+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
70137034
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
7014-
/*shift=*/0);
7035+
mulShift);
70157036

70167037
return success();
70177038
}
@@ -7106,15 +7127,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
71067127

71077128
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
71087129

7130+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
71097131
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
7110-
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
7132+
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, mulShift);
71117133

71127134
auto floorDivResult = rewriter.create<tosa::FloorOp>(
71137135
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
71147136

71157137
// (floor(input) // 2) * 2
71167138
auto evenComparison = rewriter.create<tosa::MulOp>(
7117-
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
7139+
op->getLoc(), resultTy, floorDivResult.getResult(), two, mulShift);
71187140

71197141
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
71207142
auto floorInputEven = rewriter.create<tosa::EqualOp>(
@@ -7296,9 +7318,10 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
72967318
return rewriter.notifyMatchFailure(
72977319
op, "Failed to equalize ranks among operands and result");
72987320

7321+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
72997322
Value diagonalTensor = rewriter.create<tosa::MulOp>(
73007323
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
7301-
/*shift=*/0);
7324+
mulShift);
73027325

73037326
auto resultShape = makeShapeTorchCompatible(resultType.getShape());
73047327
auto targetReduceDim = resultShape[resultType.getRank() - 1];
@@ -8587,9 +8610,10 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85878610
auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
85888611
op->getLoc(), resultType, oneMinusZi.getResult());
85898612

8613+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
85908614
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
85918615
oneMinusZiReciprocal.getResult(),
8592-
/*shift=*/0);
8616+
mulShift);
85938617

85948618
auto result =
85958619
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());
@@ -8687,9 +8711,11 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
86878711
auto reciprocalOp = rewriter.create<tosa::ReciprocalOp>(
86888712
op->getLoc(), constTenType, logOfTen.getResult());
86898713

8690-
auto result = rewriter.create<tosa::MulOp>(
8691-
op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
8692-
/*shift=*/0);
8714+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
8715+
auto result =
8716+
rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
8717+
logOfSelf.getResult(),
8718+
reciprocalOp.getResult(), mulShift);
86938719

86948720
rewriter.replaceOp(op, {result.getResult()});
86958721

@@ -8772,9 +8798,10 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
87728798
auto reciprocalOp =
87738799
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), resultType, cosOp);
87748800

8801+
auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, 0);
87758802
auto result = rewriter.create<tosa::MulOp>(
87768803
op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
8777-
/*shift=*/0);
8804+
mulShift);
87788805

87798806
rewriter.replaceOp(op, {result.getResult()});
87808807

0 commit comments

Comments
 (0)