1616#include " llvm/CodeGen/BasicTTIImpl.h"
1717#include " llvm/CodeGen/CostTable.h"
1818#include " llvm/CodeGen/TargetLowering.h"
19+ #include " llvm/IR/DerivedTypes.h"
20+ #include " llvm/IR/InstrTypes.h"
21+ #include " llvm/IR/Instruction.h"
22+ #include " llvm/IR/Instructions.h"
1923#include " llvm/IR/IntrinsicInst.h"
2024#include " llvm/IR/Intrinsics.h"
2125#include " llvm/IR/IntrinsicsAArch64.h"
2226#include " llvm/IR/PatternMatch.h"
27+ #include " llvm/Support/Casting.h"
2328#include " llvm/Support/Debug.h"
2429#include " llvm/Transforms/InstCombine/InstCombiner.h"
2530#include " llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2631#include < algorithm>
32+ #include < cassert>
2733#include < optional>
2834using namespace llvm ;
2935using namespace llvm ::PatternMatch;
@@ -3145,12 +3151,20 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31453151 return 0 ;
31463152}
31473153
3148- InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (const Instruction *I,
3149- Type *Val,
3150- unsigned Index,
3151- bool HasRealUse) {
// Cost of a single insert/extractelement on Val at lane Index. The operation
// is identified either by an IR Instruction or by a bare opcode via the
// std::variant parameter; the trailing Scalar / ScalarAndIdxToUser /
// UserToScalarAndIdx parameters let callers that have no Instruction (the
// SLP-style costing path below) describe extract->user relationships.
// NOTE(review): call sites that still pass only four arguments (see the
// `getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse)` call later in
// this diff) must rely on default arguments in the header declaration for the
// three new parameters -- TODO confirm against the header.
3154+ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (
3155+ std::variant<const Instruction *, const unsigned > InstOrOpcode, Type *Val,
3156+ unsigned Index, bool HasRealUse, Value *Scalar,
3157+ const DenseMap<std::pair<Value *, unsigned >, SmallVector<Value *, 4 >>
3158+ &ScalarAndIdxToUser,
3159+ const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned >, 4 >>
3160+ &UserToScalarAndIdx) {
31523161 assert (Val->isVectorTy () && " This must be a vector type" );
31533162
// I is non-null only when the caller identified the operation by Instruction;
// the opcode-only alternative leaves I == nullptr.
3163+ const Instruction *I =
3164+ (std::holds_alternative<const Instruction *>(InstOrOpcode)
3165+ ? get<const Instruction *>(InstOrOpcode)
3166+ : nullptr );
3167+
31543168 if (Index != -1U ) {
31553169 // Legalize the type.
31563170 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Val);
@@ -3194,6 +3208,134 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
31943208 // compile-time considerations.
31953209 }
31963210
3211+ // In case of Neon, if there exists extractelement from lane != 0 such that
3212+ // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3213+ // 2. extractelement result feeds into fmul.
3214+ // 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
3215+ // equivalent to 0.
3216+ // then the extractelement can be merged with fmul in the backend and it
3217+ // incurs no cost.
3218+ // e.g.
3219+ // define double @foo(<2 x double> %a) {
3220+ // %1 = extractelement <2 x double> %a, i32 0
3221+ // %2 = extractelement <2 x double> %a, i32 1
3222+ // %res = fmul double %1, %2
3223+ // ret double %res
3224+ // }
3225+ // %2 and %res can be merged in the backend to generate fmul v0, v0, v1.d[1]
3226+ auto ExtractCanFuseWithFmul = [&]() {
3227+ // We bail out if the extract is from lane 0.
3228+ if (Index == 0 )
3229+ return false ;
3230+
3231+ // Check if the scalar element type of the vector operand of ExtractElement
3232+ // instruction is one of the allowed types.
3233+ auto IsAllowedScalarTy = [&](const Type *T) {
3234+ return T->isFloatTy () || T->isDoubleTy () ||
3235+ (T->isHalfTy () && ST->hasFullFP16 ());
3236+ };
3237+
3238+ // Check if the extractelement user is scalar fmul.
3239+ auto IsUserFMulScalarTy = [](const Value *EEUser) {
3240+ // Check if the user is scalar fmul.
3241+ const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3242+ return BO && BO->getOpcode () == BinaryOperator::FMul &&
3243+ !BO->getType ()->isVectorTy ();
3244+ };
3245+
3246+ // InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
3247+ // with fmul does not happen.
3248+ auto IsFMulUserFAddFSub = [](const Value *FMul) {
3249+ return any_of (FMul->users (), [](const User *U) {
3250+ const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(U);
3251+ return (BO && (BO->getOpcode () == BinaryOperator::FAdd ||
3252+ BO->getOpcode () == BinaryOperator::FSub));
3253+ });
3254+ };
3255+
3256+ // Check if the type constraints on input vector type and result scalar type
3257+ // of extractelement instruction are satisfied.
3258+ auto TypeConstraintsOnEESatisfied =
3259+ [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
3260+ return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy (ScalarTy);
3261+ };
3262+
3263+ // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3264+ // certain scalar type and a certain vector register width.
// Returns true when Idx == 0 or when Idx * EltSz is an exact multiple of the
// fixed vector register width -- presumably lane 0 of a subsequent register
// after legalization splits a wide vector; confirm against the backend.
3265+ auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3266+ const unsigned &EltSz) {
3267+ auto RegWidth =
3268+ getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector)
3269+ .getFixedValue ();
3270+ return (Idx == 0 || (Idx * EltSz) % RegWidth == 0 );
3271+ };
3272+
// Opcode-only path (no Instruction): decide fusion from the caller-provided
// (scalar, lane) -> users and user -> (scalar, lane) maps.
// NOTE(review): ScalarAndIdxToUser is looked up twice (find then at); the
// find() iterator could be reused to avoid the second hash lookup.
3273+ if (std::holds_alternative<const unsigned >(InstOrOpcode)) {
3274+ if (!TypeConstraintsOnEESatisfied (Val, Val->getScalarType ()))
3275+ return false ;
3276+ const auto &ScalarIdxPair = std::make_pair (Scalar, Index);
3277+ return ScalarAndIdxToUser.find (ScalarIdxPair) !=
3278+ ScalarAndIdxToUser.end () &&
3279+ all_of (ScalarAndIdxToUser.at (ScalarIdxPair), [&](Value *U) {
3280+ if (!IsUserFMulScalarTy (U) || IsFMulUserFAddFSub (U))
3281+ return false ;
3282+ // 1. Check if the other operand is extract from lane 0 or lane
3283+ // equivalent to 0.
3284+ // 2. In case of SLP, if the other operand is not extract from
3285+ // same tree, we bail out since we can not analyze that extract.
3286+ return UserToScalarAndIdx.at (U).size () == 2 &&
3287+ all_of (UserToScalarAndIdx.at (U), [&](auto &P) {
3288+ if (ScalarIdxPair == P)
3289+ return true ; // Skip.
3290+ return IsExtractLaneEquivalentToZero (
3291+ P.second , Val->getScalarSizeInBits ());
3292+ });
3293+ });
3294+ } else {
// Instruction path: inspect the actual extractelement and its fmul users.
3295+ const ExtractElementInst *EE = cast<ExtractElementInst>(I);
3296+
3297+ const ConstantInt *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand ());
3298+ if (!IdxOp)
3299+ return false ;
3300+
3301+ if (!TypeConstraintsOnEESatisfied (EE->getVectorOperand ()->getType (),
3302+ EE->getType ()))
3303+ return false ;
3304+
3305+ return !EE->users ().empty () && all_of (EE->users (), [&](const User *U) {
3306+ if (!IsUserFMulScalarTy (U) || IsFMulUserFAddFSub (U))
3307+ return false ;
3308+
3309+ // Check if the other operand of extractelement is also extractelement
3310+ // from lane equivalent to 0.
3311+ const BinaryOperator *BO = cast<BinaryOperator>(U);
3312+ const ExtractElementInst *OtherEE = dyn_cast<ExtractElementInst>(
3313+ BO->getOperand (0 ) == EE ? BO->getOperand (1 ) : BO->getOperand (0 ));
3314+ if (OtherEE) {
// NOTE(review): the index operand is dyn_cast'd here and then cast again
// below -- the IdxOp result could be reused instead of re-casting.
3315+ const ConstantInt *IdxOp =
3316+ dyn_cast<ConstantInt>(OtherEE->getIndexOperand ());
3317+ if (!IdxOp)
3318+ return false ;
3319+ return IsExtractLaneEquivalentToZero (
3320+ cast<ConstantInt>(OtherEE->getIndexOperand ())
3321+ ->getValue ()
3322+ .getZExtValue (),
3323+ OtherEE->getType ()->getScalarSizeInBits ());
3324+ }
3325+ return true ;
3326+ });
3327+ }
// Unreachable: both branches of the if/else above return.
3328+ return false ;
3329+ };
3330+
// A fusable extractelement is free; everything else falls through to the
// subtarget's base insert/extract cost.
3331+ if (std::holds_alternative<const unsigned >(InstOrOpcode)) {
3332+ const unsigned &Opcode = get<const unsigned >(InstOrOpcode);
3333+ if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul ())
3334+ return 0 ;
3335+ } else if (I && I->getOpcode () == Instruction::ExtractElement &&
3336+ ExtractCanFuseWithFmul ())
3337+ return 0 ;
3338+
31973339 // All other insert/extracts cost this much.
31983340 return ST->getVectorInsertExtractBaseCost ();
31993341}
@@ -3207,6 +3349,19 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32073349 return getVectorInstrCostHelper (nullptr , Val, Index, HasRealUse);
32083350}
32093351
3352+ InstructionCost AArch64TTIImpl::getVectorInstrCost (
3353+ unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3354+ Value *Op0, Value *Op1, Value *Scalar,
3355+ const DenseMap<std::pair<Value *, unsigned >, SmallVector<Value *, 4 >>
3356+ &ScalarAndIdxToUser,
3357+ const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned >, 4 >>
3358+ &UserToScalarAndIdx) {
3359+ bool HasRealUse =
3360+ Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
3361+ return getVectorInstrCostHelper (Opcode, Val, Index, HasRealUse, Scalar,
3362+ ScalarAndIdxToUser, UserToScalarAndIdx);
3363+ }
3364+
32103365InstructionCost AArch64TTIImpl::getVectorInstrCost (const Instruction &I,
32113366 Type *Val,
32123367 TTI::TargetCostKind CostKind,
0 commit comments