1111#include " AArch64PerfectShuffle.h"
1212#include " MCTargetDesc/AArch64AddressingModes.h"
1313#include " Utils/AArch64SMEAttributes.h"
14+ #include " llvm/ADT/DenseMap.h"
1415#include " llvm/Analysis/IVDescriptors.h"
1516#include " llvm/Analysis/LoopInfo.h"
1617#include " llvm/Analysis/TargetTransformInfo.h"
@@ -3177,10 +3178,10 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31773178 return 0 ;
31783179}
31793180
3180- InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (const Instruction *I,
3181- Type *Val ,
3182- unsigned Index ,
3183- bool HasRealUse ) {
3181+ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (
3182+ unsigned Opcode, Type *Val, unsigned Index, bool HasRealUse ,
3183+ const Instruction *I, Value *Scalar ,
3184+ ArrayRef<std::tuple<Value *, User *, int >> ScalarUserAndIdx ) {
31843185 assert (Val->isVectorTy () && " This must be a vector type" );
31853186
31863187 if (Index != -1U ) {
@@ -3226,6 +3227,119 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
32263227 // compile-time considerations.
32273228 }
32283229
3230+ // In case of Neon, if there exists extractelement from lane != 0 such that
3231+ // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3232+ // 2. extractelement result feeds into fmul.
3233+ // 3. Other operand of fmul is an extractelement from lane 0 or lane
3234+ // equivalent to 0.
3235+ // then the extractelement can be merged with fmul in the backend and it
3236+ // incurs no cost.
3237+ // e.g.
3238+ // define double @foo(<2 x double> %a) {
3239+ // %1 = extractelement <2 x double> %a, i32 0
3240+ // %2 = extractelement <2 x double> %a, i32 1
3241+ // %res = fmul double %1, %2
3242+ // ret double %res
3243+ // }
3244+ // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3245+ auto ExtractCanFuseWithFmul = [&]() {
3246+ // We bail out if the extract is from lane 0.
3247+ if (Index == 0 )
3248+ return false ;
3249+
3250+ // Check if the scalar element type of the vector operand of ExtractElement
3251+ // instruction is one of the allowed types.
3252+ auto IsAllowedScalarTy = [&](const Type *T) {
3253+ return T->isFloatTy () || T->isDoubleTy () ||
3254+ (T->isHalfTy () && ST->hasFullFP16 ());
3255+ };
3256+
3257+ // Check if the extractelement user is scalar fmul.
3258+ auto IsUserFMulScalarTy = [](const Value *EEUser) {
3259+ // Check if the user is scalar fmul.
3260+ const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3261+ return BO && BO->getOpcode () == BinaryOperator::FMul &&
3262+ !BO->getType ()->isVectorTy ();
3263+ };
3264+
3265+ // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3266+ // certain scalar type and a certain vector register width.
3267+ auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3268+ const unsigned &EltSz) {
3269+ auto RegWidth =
3270+ getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector)
3271+ .getFixedValue ();
3272+ return (Idx == 0 || (Idx * EltSz) % RegWidth == 0 );
3273+ };
3274+
3275+ // Check if the type constraints on input vector type and result scalar type
3276+ // of extractelement instruction are satisfied.
3277+ if (!isa<FixedVectorType>(Val) || !IsAllowedScalarTy (Val->getScalarType ()))
3278+ return false ;
3279+
3280+ if (Scalar) {
3281+ DenseMap<User *, unsigned > UserToExtractIdx;
3282+ for (auto *U : Scalar->users ()) {
3283+ if (!IsUserFMulScalarTy (U))
3284+ return false ;
3285+ // Recording entry for the user is important. Index value is not
3286+ // important.
3287+ UserToExtractIdx[U];
3288+ }
3289+ for (auto &[S, U, L] : ScalarUserAndIdx) {
3290+ for (auto *U : S->users ()) {
3291+ if (UserToExtractIdx.find (U) != UserToExtractIdx.end ()) {
3292+ auto *FMul = cast<BinaryOperator>(U);
3293+ auto *Op0 = FMul->getOperand (0 );
3294+ auto *Op1 = FMul->getOperand (1 );
3295+ if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
3296+ UserToExtractIdx[U] = L;
3297+ break ;
3298+ }
3299+ }
3300+ }
3301+ }
3302+ for (auto &[U, L] : UserToExtractIdx) {
3303+ if (!IsExtractLaneEquivalentToZero (Index, Val->getScalarSizeInBits ()) &&
3304+ !IsExtractLaneEquivalentToZero (L, Val->getScalarSizeInBits ()))
3305+ return false ;
3306+ }
3307+ } else {
3308+ const auto *EE = cast<ExtractElementInst>(I);
3309+
3310+ const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand ());
3311+ if (!IdxOp)
3312+ return false ;
3313+
3314+ return !EE->users ().empty () && all_of (EE->users (), [&](const User *U) {
3315+ if (!IsUserFMulScalarTy (U))
3316+ return false ;
3317+
3318+ // Check if the other operand of extractelement is also extractelement
3319+ // from lane equivalent to 0.
3320+ const auto *BO = cast<BinaryOperator>(U);
3321+ const auto *OtherEE = dyn_cast<ExtractElementInst>(
3322+ BO->getOperand (0 ) == EE ? BO->getOperand (1 ) : BO->getOperand (0 ));
3323+ if (OtherEE) {
3324+ const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand ());
3325+ if (!IdxOp)
3326+ return false ;
3327+ return IsExtractLaneEquivalentToZero (
3328+ cast<ConstantInt>(OtherEE->getIndexOperand ())
3329+ ->getValue ()
3330+ .getZExtValue (),
3331+ OtherEE->getType ()->getScalarSizeInBits ());
3332+ }
3333+ return true ;
3334+ });
3335+ }
3336+ return true ;
3337+ };
3338+
3339+ if (Opcode == Instruction::ExtractElement && (I || Scalar) &&
3340+ ExtractCanFuseWithFmul ())
3341+ return 0 ;
3342+
32293343 // All other insert/extracts cost this much.
32303344 return ST->getVectorInsertExtractBaseCost ();
32313345}
@@ -3236,14 +3350,23 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32363350 Value *Op1) {
32373351 bool HasRealUse =
32383352 Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
3239- return getVectorInstrCostHelper (nullptr , Val, Index, HasRealUse);
3353+ return getVectorInstrCostHelper (Opcode, Val, Index, HasRealUse);
3354+ }
3355+
3356+ InstructionCost AArch64TTIImpl::getVectorInstrCost (
3357+ unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3358+ Value *Scalar,
3359+ ArrayRef<std::tuple<Value *, User *, int >> ScalarUserAndIdx) {
3360+ return getVectorInstrCostHelper (Opcode, Val, Index, false , nullptr , Scalar,
3361+ ScalarUserAndIdx);
32403362}
32413363
32423364InstructionCost AArch64TTIImpl::getVectorInstrCost (const Instruction &I,
32433365 Type *Val,
32443366 TTI::TargetCostKind CostKind,
32453367 unsigned Index) {
3246- return getVectorInstrCostHelper (&I, Val, Index, true /* HasRealUse */ );
3368+ return getVectorInstrCostHelper (I.getOpcode (), Val, Index,
3369+ true /* HasRealUse */ , &I);
32473370}
32483371
32493372InstructionCost AArch64TTIImpl::getScalarizationOverhead (
0 commit comments