1010#include " AArch64ExpandImm.h"
1111#include " AArch64PerfectShuffle.h"
1212#include " MCTargetDesc/AArch64AddressingModes.h"
13+ #include " llvm/ADT/DenseMap.h"
14+ #include " llvm/ADT/STLExtras.h"
1315#include " llvm/Analysis/IVDescriptors.h"
1416#include " llvm/Analysis/LoopInfo.h"
1517#include " llvm/Analysis/TargetTransformInfo.h"
1618#include " llvm/CodeGen/BasicTTIImpl.h"
1719#include " llvm/CodeGen/CostTable.h"
1820#include " llvm/CodeGen/TargetLowering.h"
21+ #include " llvm/IR/DerivedTypes.h"
22+ #include " llvm/IR/InstrTypes.h"
23+ #include " llvm/IR/Instruction.h"
24+ #include " llvm/IR/Instructions.h"
1925#include " llvm/IR/IntrinsicInst.h"
2026#include " llvm/IR/Intrinsics.h"
2127#include " llvm/IR/IntrinsicsAArch64.h"
2228#include " llvm/IR/PatternMatch.h"
29+ #include " llvm/IR/User.h"
30+ #include " llvm/Support/Casting.h"
2331#include " llvm/Support/Debug.h"
2432#include " llvm/Transforms/InstCombine/InstCombiner.h"
2533#include " llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2634#include < algorithm>
35+ // #include <cassert>
2736#include < optional>
2837using namespace llvm ;
2938using namespace llvm ::PatternMatch;
@@ -3145,12 +3154,16 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31453154 return 0 ;
31463155}
31473156
3148- InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (const Instruction *I,
3149- Type *Val,
3150- unsigned Index ,
3151- bool HasRealUse ) {
3157+ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper (
3158+ std::variant< const Instruction *, const unsigned > InstOrOpcode, Type *Val,
3159+ unsigned Index, bool HasRealUse, Value *Scalar ,
3160+ ArrayRef<std::tuple<Value *, User *, int >> ScalarUserAndIdx ) {
31523161 assert (Val->isVectorTy () && " This must be a vector type" );
31533162
3163+ const auto *I = (std::holds_alternative<const Instruction *>(InstOrOpcode)
3164+ ? get<const Instruction *>(InstOrOpcode)
3165+ : nullptr );
3166+
31543167 if (Index != -1U ) {
31553168 // Legalize the type.
31563169 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Val);
@@ -3194,6 +3207,143 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
31943207 // compile-time considerations.
31953208 }
31963209
3210+ // In case of Neon, if there exists extractelement from lane != 0 such that
3211+ // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3212+ // 2. extractelement result feeds into fmul.
3213+ // 3. Other operand of fmul is an extractelement from lane 0 or lane
3214+ // equivalent to 0.
3215+ // then the extractelement can be merged with fmul in the backend and it
3216+ // incurs no cost.
3217+ // e.g.
3218+ // define double @foo(<2 x double> %a) {
3219+ // %1 = extractelement <2 x double> %a, i32 0
3220+ // %2 = extractelement <2 x double> %a, i32 1
3221+ // %res = fmul double %1, %2
3222+ // ret double %res
3223+ // }
3224+ // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3225+ auto ExtractCanFuseWithFmul = [&]() {
3226+ // We bail out if the extract is from lane 0.
3227+ if (Index == 0 )
3228+ return false ;
3229+
3230+ // Check if the scalar element type of the vector operand of ExtractElement
3231+ // instruction is one of the allowed types.
3232+ auto IsAllowedScalarTy = [&](const Type *T) {
3233+ return T->isFloatTy () || T->isDoubleTy () ||
3234+ (T->isHalfTy () && ST->hasFullFP16 ());
3235+ };
3236+
3237+ // Check if the extractelement user is scalar fmul.
3238+ auto IsUserFMulScalarTy = [](const Value *EEUser) {
3239+ // Check if the user is scalar fmul.
3240+ const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3241+ return BO && BO->getOpcode () == BinaryOperator::FMul &&
3242+ !BO->getType ()->isVectorTy ();
3243+ };
3244+
3245+ // InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
3246+ // with fmul does not happen.
3247+ auto IsFMulUserFAddFSub = [](const Value *FMul) {
3248+ return any_of (FMul->users (), [](const User *U) {
3249+ const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
3250+ return (BO && (BO->getOpcode () == BinaryOperator::FAdd ||
3251+ BO->getOpcode () == BinaryOperator::FSub));
3252+ });
3253+ };
3254+
3255+ // Check if the type constraints on input vector type and result scalar type
3256+ // of extractelement instruction are satisfied.
3257+ auto TypeConstraintsOnEESatisfied =
3258+ [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
3259+ return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy (ScalarTy);
3260+ };
3261+
3262+ // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3263+ // certain scalar type and a certain vector register width.
3264+ auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3265+ const unsigned &EltSz) {
3266+ auto RegWidth =
3267+ getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector)
3268+ .getFixedValue ();
3269+ return (Idx == 0 || (Idx * EltSz) % RegWidth == 0 );
3270+ };
3271+
3272+ if (std::holds_alternative<const unsigned >(InstOrOpcode)) {
3273+ if (!TypeConstraintsOnEESatisfied (Val, Val->getScalarType ()))
3274+ return false ;
3275+
3276+ DenseMap<User *, unsigned > UserToExtractIdx;
3277+ for (auto *U : Scalar->users ()) {
3278+ if (!IsUserFMulScalarTy (U) || IsFMulUserFAddFSub (U))
3279+ return false ;
3280+ // Recording entry for the user is important. Index value is not
3281+ // important.
3282+ UserToExtractIdx[U];
3283+ }
3284+ for (auto &[S, U, L] : ScalarUserAndIdx) {
3285+ for (auto *U : S->users ()) {
3286+ if (UserToExtractIdx.find (U) != UserToExtractIdx.end ()) {
3287+ auto *FMul = cast<BinaryOperator>(U);
3288+ auto *Op0 = FMul->getOperand (0 );
3289+ auto *Op1 = FMul->getOperand (1 );
3290+ if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
3291+ UserToExtractIdx[U] = L;
3292+ break ;
3293+ }
3294+ }
3295+ }
3296+ }
3297+ for (auto &[U, L] : UserToExtractIdx) {
3298+ if (!IsExtractLaneEquivalentToZero (Index, Val->getScalarSizeInBits ()) &&
3299+ !IsExtractLaneEquivalentToZero (L, Val->getScalarSizeInBits ()))
3300+ return false ;
3301+ }
3302+ } else {
3303+ const auto *EE = cast<ExtractElementInst>(I);
3304+
3305+ const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand ());
3306+ if (!IdxOp)
3307+ return false ;
3308+
3309+ if (!TypeConstraintsOnEESatisfied (EE->getVectorOperand ()->getType (),
3310+ EE->getType ()))
3311+ return false ;
3312+
3313+ return !EE->users ().empty () && all_of (EE->users (), [&](const User *U) {
3314+ if (!IsUserFMulScalarTy (U) || IsFMulUserFAddFSub (U))
3315+ return false ;
3316+
3317+ // Check if the other operand of extractelement is also extractelement
3318+ // from lane equivalent to 0.
3319+ const auto *BO = cast<BinaryOperator>(U);
3320+ const auto *OtherEE = dyn_cast<ExtractElementInst>(
3321+ BO->getOperand (0 ) == EE ? BO->getOperand (1 ) : BO->getOperand (0 ));
3322+ if (OtherEE) {
3323+ const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand ());
3324+ if (!IdxOp)
3325+ return false ;
3326+ return IsExtractLaneEquivalentToZero (
3327+ cast<ConstantInt>(OtherEE->getIndexOperand ())
3328+ ->getValue ()
3329+ .getZExtValue (),
3330+ OtherEE->getType ()->getScalarSizeInBits ());
3331+ }
3332+ return true ;
3333+ });
3334+ }
3335+ return true ;
3336+ };
3337+
3338+ if (std::holds_alternative<const unsigned >(InstOrOpcode)) {
3339+ const unsigned &Opcode = get<const unsigned >(InstOrOpcode);
3340+ if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul ())
3341+ return 0 ;
3342+ } else if (I && I->getOpcode () == Instruction::ExtractElement &&
3343+ ExtractCanFuseWithFmul ()) {
3344+ return 0 ;
3345+ }
3346+
31973347 // All other insert/extracts cost this much.
31983348 return ST->getVectorInsertExtractBaseCost ();
31993349}
@@ -3207,6 +3357,14 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32073357 return getVectorInstrCostHelper (nullptr , Val, Index, HasRealUse);
32083358}
32093359
3360+ InstructionCost AArch64TTIImpl::getVectorInstrCost (
3361+ unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3362+ Value *Scalar,
3363+ ArrayRef<std::tuple<Value *, User *, int >> ScalarUserAndIdx) {
3364+ return getVectorInstrCostHelper (Opcode, Val, Index, false , Scalar,
3365+ ScalarUserAndIdx);
3366+ }
3367+
32103368InstructionCost AArch64TTIImpl::getVectorInstrCost (const Instruction &I,
32113369 Type *Val,
32123370 TTI::TargetCostKind CostKind,
0 commit comments