@@ -10,20 +10,28 @@
 #include "AArch64ExpandImm.h"
 #include "AArch64PerfectShuffle.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/CostTable.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/User.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
 #include <algorithm>
+#include <cassert>
 #include <optional>
 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -3145,12 +3153,16 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
-                                                         Type *Val,
-                                                         unsigned Index,
-                                                         bool HasRealUse) {
+InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
+    std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
+    unsigned Index, bool HasRealUse, Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
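+  // The helper accepts either a concrete Instruction (when the caller has
+  // one in hand) or just an opcode; unwrap the instruction pointer when it
+  // is present.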
+  const auto *I = (std::holds_alternative<const Instruction *>(InstOrOpcode)
+                       ? get<const Instruction *>(InstOrOpcode)
+                       : nullptr);
+
   if (Index != -1U) {
     // Legalize the type.
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
@@ -3194,6 +3206,149 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
   // compile-time considerations.
   }
 
+  // In case of Neon, if there exists extractelement from lane != 0 such that
+  // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+  // 2. extractelement result feeds into fmul.
+  // 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
+  //    equivalent to 0.
+  // then the extractelement can be merged with fmul in the backend and it
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul d0, d0, v0.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of the
+    // ExtractElement instruction is one of the allowed types.
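+    // (These are the element types for which an indexed/by-element FMUL
+    // encoding exists; half requires the full FP16 extension.)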
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is a scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+    // InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
+    // with fmul does not happen.
+    auto IsFMulUserFAddFSub = [](const Value *FMul) {
+      return any_of(FMul->users(), [](const User *U) {
+        const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
+        return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
+                       BO->getOpcode() == BinaryOperator::FSub));
+      });
+    };
+
+    // Check if the type constraints on the input vector type and the result
+    // scalar type of the extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is from lane 0 or a lane equivalent to 0 for
+    // a certain scalar type and a certain vector register width.
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
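+      // E.g. with 64-bit elements and 128-bit vector registers, lane 2 of a
+      // <4 x double> legalizes to lane 0 of the second register
+      // (2 * 64 = 128, a multiple of the register width), so it also needs
+      // no lane move.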
+      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+    };
+
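+    // With only an opcode, the extract's users are not materialized yet, so
+    // reconstruct them from the caller-provided (scalar, user, lane) tuples;
+    // with a concrete instruction (else branch) we walk its users directly.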
+    if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+
+      for (auto &RefT : ScalarUserAndIdx) {
+        Value *RefS = get<0>(RefT);
+        User *RefU = get<1>(RefT);
+        const int &RefL = get<2>(RefT);
+
+        // Analyze all the users which have the same scalar/index as
+        // Scalar/Index.
+        if (RefS != Scalar || RefL != Index)
+          continue;
+
+        // Check if the user of the {Scalar, Index} pair is an fmul user.
+        if (!IsUserFMulScalarTy(RefU) || IsFMulUserFAddFSub(RefU))
+          return false;
+
+        // For RefU, check if the other operand is an extract from the same
+        // SLP tree. If not, we bail out since we can't analyze extracts from
+        // other SLP trees.
+        unsigned NumExtractEltsIntoUser = 0;
+        for (auto &CmpT : ScalarUserAndIdx) {
+          User *CmpU = get<1>(CmpT);
+          if (CmpT == RefT || CmpU != RefU)
+            continue;
+          ++NumExtractEltsIntoUser;
+          const int &CmpL = get<2>(CmpT);
+          if (!IsExtractLaneEquivalentToZero(CmpL, Val->getScalarSizeInBits()))
+            return false;
+        }
+        // We know this is an fmul user with just 2 operands, one being RefT.
+        // If we can't find CmpT as the other operand, then bail out.
+        if (NumExtractEltsIntoUser != 1)
+          return false;
+      }
+    } else {
+      const auto *EE = cast<ExtractElementInst>(I);
+
+      const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+          return false;
+
+        // Check if the other operand of the extractelement is also an
+        // extractelement from a lane equivalent to 0.
+        const auto *BO = cast<BinaryOperator>(U);
+        const auto *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              cast<ConstantInt>(OtherEE->getIndexOperand())
+                  ->getValue()
+                  .getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return true;
+  };
+
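+  // An extract that can fuse with a scalar fmul in this way is free; dispatch
+  // on whether we were given a full instruction or just an opcode.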
+  if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+    const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
+    if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
+      return 0;
+  } else if (I && I->getOpcode() == Instruction::ExtractElement &&
+             ExtractCanFuseWithFmul()) {
+    return 0;
+  }
+
   // All other insert/extracts cost this much.
   return ST->getVectorInsertExtractBaseCost();
 }
@@ -3207,6 +3362,14 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
 }
 
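+// Overload that identifies the extract by opcode and index only, taking the
+// scalar it would produce plus all (scalar, user, lane) uses in the tree, so
+// that fusion can be detected before the extractelement is materialized
+// (used by callers such as the SLP vectorizer).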
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Scalar,
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
+  return getVectorInstrCostHelper(Opcode, Val, Index, false, Scalar,
+                                  ScalarUserAndIdx);
+}
+
 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                    Type *Val,
                                                    TTI::TargetCostKind CostKind,