 #include "AArch64ExpandImm.h"
 #include "AArch64PerfectShuffle.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -3145,10 +3146,10 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
-                                                         Type *Val,
-                                                         unsigned Index,
-                                                         bool HasRealUse) {
+InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
+    Type *Val, unsigned Index, bool HasRealUse, const Instruction *I,
+    std::optional<unsigned> Opcode, Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   if (Index != -1U) {
@@ -3194,6 +3195,138 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
     // compile-time considerations.
   }
 
+  // In case of Neon, if there exists extractelement from lane != 0 such that
+  // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+  // 2. extractelement result feeds into fmul.
+  // 3. Other operand of fmul is an extractelement from lane 0 or lane
+  //    equivalent to 0.
+  // then the extractelement can be merged with fmul in the backend and it
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul d0, d0, v0.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of
+    // ExtractElement instruction is one of the allowed types.
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      // Check if the user is scalar fmul.
+      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+    // InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
+    // with fmul does not happen.
+    auto IsFMulUserFAddFSub = [](const Value *FMul) {
+      return any_of(FMul->users(), [](const User *U) {
+        const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
+        return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
+                       BO->getOpcode() == BinaryOperator::FSub));
+      });
+    };
+
+    // Check if the type constraints on input vector type and result scalar
+    // type of extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is from lane 0 or lane equivalent to 0 for a
+    // certain scalar type and a certain vector register width.
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
+      return Idx == 0 || (Idx * EltSz) % RegWidth == 0;
+    };
+
+    if (Opcode.has_value()) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+
+      DenseMap<User *, unsigned> UserToExtractIdx;
+      for (auto *U : Scalar->users()) {
+        if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+          return false;
+        // Create the map entry for the user here; the actual lane index is
+        // filled in by the loop over ScalarUserAndIdx below.
+        UserToExtractIdx[U];
+      }
+      for (auto &[S, U, L] : ScalarUserAndIdx) {
+        for (auto *U : S->users()) {
+          if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
+            auto *FMul = cast<BinaryOperator>(U);
+            auto *Op0 = FMul->getOperand(0);
+            auto *Op1 = FMul->getOperand(1);
+            if (Op0 == S || Op1 == S) {
+              UserToExtractIdx[U] = L;
+              break;
+            }
+          }
+        }
+      }
+      for (auto &[U, L] : UserToExtractIdx) {
+        if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
+            !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
+          return false;
+      }
+    } else {
+      const auto *EE = cast<ExtractElementInst>(I);
+
+      const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+          return false;
+
+        // Check if the other operand of extractelement is also extractelement
+        // from lane equivalent to 0.
+        const auto *BO = cast<BinaryOperator>(U);
+        const auto *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              cast<ConstantInt>(OtherEE->getIndexOperand())
+                  ->getValue()
+                  .getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return true;
+  };
+
+  unsigned InstOpcode = I ? I->getOpcode() : Opcode.value();
+  if (InstOpcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
+    return 0;
+
   // All other insert/extracts cost this much.
   return ST->getVectorInsertExtractBaseCost();
 }
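
The lane-equivalence rule above is easiest to see with a vector wider than one NEON register: with 128-bit registers, a <4 x double> occupies two registers, so lane 2 begins at bit 128 and is itself the low lane of the second register, needing no lane move before an fmul. A minimal standalone sketch of the same arithmetic follows; the hard-coded 128-bit RegWidth is an assumption for illustration, whereas the patch queries it via getRegisterBitWidth:

    #include <cstdio>

    // Mirrors the IsExtractLaneEquivalentToZero lambda from the patch: an
    // extract index is "equivalent to lane 0" when its bit offset is a
    // multiple of the register width, i.e. it is the low lane of a register.
    static bool isLaneEquivalentToZero(unsigned Idx, unsigned EltSz,
                                       unsigned RegWidth) {
      return Idx == 0 || (Idx * EltSz) % RegWidth == 0;
    }

    int main() {
      // <4 x double>: 64-bit elements on 128-bit registers. Lanes 0 and 2
      // are low lanes; lanes 1 and 3 would need the fused fmul-by-element.
      for (unsigned Idx = 0; Idx < 4; ++Idx)
        std::printf("lane %u: %s\n", Idx,
                    isLaneEquivalentToZero(Idx, /*EltSz=*/64, /*RegWidth=*/128)
                        ? "equivalent to lane 0"
                        : "not equivalent");
    }
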
@@ -3204,14 +3337,22 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                                    Value *Op1) {
   bool HasRealUse =
       Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
-  return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
+  return getVectorInstrCostHelper(Val, Index, HasRealUse);
+}
+
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
+  return getVectorInstrCostHelper(Val, Index, false, nullptr, Opcode, Scalar,
+                                  ScalarUserAndIdx);
 }
 
 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                    Type *Val,
                                                    TTI::TargetCostKind CostKind,
                                                    unsigned Index) {
-  return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
+  return getVectorInstrCostHelper(Val, Index, true /* HasRealUse */, &I);
 }
 
 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
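
For context, a hypothetical caller-side sketch of the new overload. The way ScalarUserAndIdx is assembled here, and the names Scalars, VecTy, and TTI, are illustrative assumptions rather than the actual in-tree calling code (the SLP vectorizer is the intended client):

    // Cost the lane-1 extract of a hypothetical 2-element vectorization,
    // handing the helper every (scalar, user, lane) triple so it can inspect
    // the fmul users of all lanes when deciding whether the extract fuses
    // for free. Assumes ArrayRef<Value *> Scalars and FixedVectorType *VecTy.
    SmallVector<std::tuple<Value *, User *, int>> ScalarUserAndIdx;
    for (auto [Lane, Scalar] : enumerate(Scalars))
      for (User *U : Scalar->users())
        ScalarUserAndIdx.emplace_back(Scalar, U, Lane);

    InstructionCost Cost = TTI.getVectorInstrCost(
        Instruction::ExtractElement, VecTy, TTI::TCK_RecipThroughput,
        /*Index=*/1, /*Scalar=*/Scalars[1], ScalarUserAndIdx);
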