@@ -11,6 +11,7 @@
 #include "AArch64PerfectShuffle.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
 #include "Utils/AArch64SMEAttributes.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -3177,10 +3178,10 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
-                                                         Type *Val,
-                                                         unsigned Index,
-                                                         bool HasRealUse) {
+InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
+    Type *Val, unsigned Index, bool HasRealUse, const Instruction *I,
+    std::optional<unsigned> Opcode, Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   if (Index != -1U) {
@@ -3226,6 +3227,128 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
     // compile-time considerations.
   }
 
+  // In the case of Neon, if there exists an extractelement from lane != 0
+  // such that
+  // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+  // 2. extractelement result feeds into fmul.
+  // 3. The other operand of fmul is an extractelement from lane 0 or a lane
+  //    equivalent to 0.
+  // then the extractelement can be merged with fmul in the backend and it
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of ExtractElement
+    // instruction is one of the allowed types.
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is a scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+    // Check if the type constraints on the input vector type and the result
+    // scalar type of an extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is lane 0 or a lane equivalent to 0 for the
+    // given scalar type and vector register width.
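+    // For example, with 128-bit vector registers, lane 4 of an <8 x float>
+    // value is lane 0 of the second 128-bit register backing it
+    // (4 * 32 % 128 == 0), so extracting it needs no lane move either.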
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
+      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+    };
+
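+    // This point can be reached in two ways: through the opcode-based
+    // overload, which supplies the candidate scalars and their users before
+    // any extractelement instruction exists, or through the instruction-based
+    // overload, which inspects an existing extractelement directly.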
+    if (Opcode.has_value()) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+
+      DenseMap<User *, unsigned> UserToExtractIdx;
+      for (auto *U : Scalar->users()) {
+        if (!IsUserFMulScalarTy(U))
+          return false;
+        // Record an entry for this user now; the index value does not matter
+        // yet and is filled in below.
+        UserToExtractIdx[U];
+      }
+      for (auto &[S, U, L] : ScalarUserAndIdx) {
+        for (auto *U : S->users()) {
+          if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
+            auto *FMul = cast<BinaryOperator>(U);
+            auto *Op0 = FMul->getOperand(0);
+            auto *Op1 = FMul->getOperand(1);
+            if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+              UserToExtractIdx[U] = L;
+              break;
+            }
+          }
+        }
+      }
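+      // For each fmul user, at least one of the two lanes feeding it must be
+      // lane 0 or a lane equivalent to 0 for the fold to apply.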
+      for (auto &[U, L] : UserToExtractIdx) {
+        if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
+            !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
+          return false;
+      }
+    } else {
+      const auto *EE = cast<ExtractElementInst>(I);
+
+      const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U))
+          return false;
+
+        // Check if the other operand of the fmul is also an extractelement
+        // from a lane equivalent to 0.
+        const auto *BO = cast<BinaryOperator>(U);
+        const auto *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              IdxOp->getValue().getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return true;
+  };
+
+  unsigned InstOpcode = I ? I->getOpcode() : Opcode.value();
+  if (InstOpcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
+    return 0;
+
   // All other insert/extracts cost this much.
   return ST->getVectorInsertExtractBaseCost();
 }
@@ -3236,14 +3359,22 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                                    Value *Op1) {
   bool HasRealUse =
       Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
-  return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
+  return getVectorInstrCostHelper(Val, Index, HasRealUse);
+}
+
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
+  return getVectorInstrCostHelper(Val, Index, false, nullptr, Opcode, Scalar,
+                                  ScalarUserAndIdx);
 }
 
 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                    Type *Val,
                                                    TTI::TargetCostKind CostKind,
                                                    unsigned Index) {
-  return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
+  return getVectorInstrCostHelper(Val, Index, true /* HasRealUse */, &I);
 }
 
 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
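
For context, a minimal sketch of how a caller might drive the new opcode-based overload; the names TTI, Scalars, VecTy, and Index below are illustrative assumptions, not part of this patch. Each scalar destined for the vector is paired with its users and its lane, so the helper can recognize extracts that feed scalar fmuls:

  // Hypothetical caller (sketch): gather (scalar, user, lane) triples for the
  // would-be vector elements, then cost the extract of one lane.
  SmallVector<std::tuple<Value *, User *, int>> ScalarUserAndIdx;
  for (auto [Lane, S] : enumerate(Scalars))
    for (User *U : S->users())
      ScalarUserAndIdx.emplace_back(S, U, Lane);

  InstructionCost Cost = TTI.getVectorInstrCost(
      Instruction::ExtractElement, VecTy, TTI::TCK_RecipThroughput, Index,
      /*Scalar=*/Scalars[Index], ScalarUserAndIdx);

With the pattern described in the header comment (one operand extracted from lane 0 or an equivalent lane, the other from a nonzero lane, both feeding a scalar fmul), the returned cost is 0 rather than the base insert/extract cost.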