Skip to content

Commit 08e2879

Browse files
committed
[CostModel][AArch64] Make extractelement, with fmul user, free whenever possible
In case of Neon, if there exists extractelement from lane != 0 such that 1. extractelement does not necessitate a move from vector_reg -> GPR. 2. extractelement result feeds into fmul. 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane equivalent to 0. then the extractelement can be merged with fmul in the backend and it incurs no cost. e.g. define double @foo(<2 x double> %a) { %1 = extractelement <2 x double> %a, i32 0 %2 = extractelement <2 x double> %a, i32 1 %res = fmul double %1, %2 ret double %res } %2 and %res can be merged in the backend to generate: fmul d0, d0, v0.d[1] The change was tested with SPEC FP(C/C++) on Neoverse-v2. Compile time impact: None Performance impact: Observing 1.3-1.7% uplift on lbm benchmark with -flto depending upon the config.
1 parent 924a64a commit 08e2879

File tree

9 files changed

+282
-74
lines changed

9 files changed

+282
-74
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
2323

2424
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/ArrayRef.h"
2526
#include "llvm/IR/FMF.h"
2627
#include "llvm/IR/InstrTypes.h"
2728
#include "llvm/IR/PassManager.h"
@@ -1392,6 +1393,16 @@ class TargetTransformInfo {
13921393
unsigned Index = -1, Value *Op0 = nullptr,
13931394
Value *Op1 = nullptr) const;
13941395

1396+
/// \return The expected cost of vector Insert and Extract.
1397+
/// Use -1 to indicate that there is no information on the index value.
1398+
/// This is used when the instruction is not available; a typical use
1399+
/// case is to provision the cost of vectorization/scalarization in
1400+
/// vectorizer passes.
1401+
InstructionCost getVectorInstrCost(
1402+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1403+
Value *Scalar,
1404+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const;
1405+
13951406
/// \return The expected cost of vector Insert and Extract.
13961407
/// This is used when instruction is available, and implementation
13971408
/// asserts 'I' is not nullptr.
@@ -2062,6 +2073,12 @@ class TargetTransformInfo::Concept {
20622073
TTI::TargetCostKind CostKind,
20632074
unsigned Index, Value *Op0,
20642075
Value *Op1) = 0;
2076+
2077+
virtual InstructionCost getVectorInstrCost(
2078+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2079+
Value *Scalar,
2080+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) = 0;
2081+
20652082
virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
20662083
TTI::TargetCostKind CostKind,
20672084
unsigned Index) = 0;
@@ -2726,6 +2743,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27262743
Value *Op1) override {
27272744
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
27282745
}
2746+
InstructionCost getVectorInstrCost(
2747+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2748+
Value *Scalar,
2749+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) override {
2750+
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Scalar,
2751+
ScalarUserAndIdx);
2752+
}
27292753
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
27302754
TTI::TargetCostKind CostKind,
27312755
unsigned Index) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,13 @@ class TargetTransformInfoImplBase {
683683
return 1;
684684
}
685685

686+
InstructionCost getVectorInstrCost(
687+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
688+
Value *Scalar,
689+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
690+
return 1;
691+
}
692+
686693
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
687694
TTI::TargetCostKind CostKind,
688695
unsigned Index) const {

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#define LLVM_CODEGEN_BASICTTIIMPL_H
1818

1919
#include "llvm/ADT/APInt.h"
20-
#include "llvm/ADT/ArrayRef.h"
2120
#include "llvm/ADT/BitVector.h"
2221
#include "llvm/ADT/SmallPtrSet.h"
2322
#include "llvm/ADT/SmallVector.h"
@@ -1277,12 +1276,20 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
12771276
return 1;
12781277
}
12791278

1280-
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
1281-
TTI::TargetCostKind CostKind,
1282-
unsigned Index, Value *Op0, Value *Op1) {
1279+
virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
1280+
TTI::TargetCostKind CostKind,
1281+
unsigned Index, Value *Op0,
1282+
Value *Op1) {
12831283
return getRegUsageForType(Val->getScalarType());
12841284
}
12851285

1286+
InstructionCost getVectorInstrCost(
1287+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1288+
Value *Scalar,
1289+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
1290+
return getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr, nullptr);
1291+
}
1292+
12861293
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
12871294
TTI::TargetCostKind CostKind,
12881295
unsigned Index) {

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,19 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(
10371037
return Cost;
10381038
}
10391039

1040+
InstructionCost TargetTransformInfo::getVectorInstrCost(
1041+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1042+
Value *Scalar,
1043+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
1044+
// FIXME: Assert that Opcode is either InsertElement or ExtractElement.
1045+
// This is mentioned in the interface description and respected by all
1046+
// callers, but never asserted upon.
1047+
InstructionCost Cost = TTIImpl->getVectorInstrCost(
1048+
Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx);
1049+
assert(Cost >= 0 && "TTI should not produce negative costs!");
1050+
return Cost;
1051+
}
1052+
10401053
InstructionCost
10411054
TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
10421055
TTI::TargetCostKind CostKind,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 167 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,28 @@
1010
#include "AArch64ExpandImm.h"
1111
#include "AArch64PerfectShuffle.h"
1212
#include "MCTargetDesc/AArch64AddressingModes.h"
13+
#include "llvm/ADT/STLExtras.h"
1314
#include "llvm/Analysis/IVDescriptors.h"
1415
#include "llvm/Analysis/LoopInfo.h"
1516
#include "llvm/Analysis/TargetTransformInfo.h"
1617
#include "llvm/CodeGen/BasicTTIImpl.h"
1718
#include "llvm/CodeGen/CostTable.h"
1819
#include "llvm/CodeGen/TargetLowering.h"
20+
#include "llvm/IR/DerivedTypes.h"
21+
#include "llvm/IR/InstrTypes.h"
22+
#include "llvm/IR/Instruction.h"
23+
#include "llvm/IR/Instructions.h"
1924
#include "llvm/IR/IntrinsicInst.h"
2025
#include "llvm/IR/Intrinsics.h"
2126
#include "llvm/IR/IntrinsicsAArch64.h"
2227
#include "llvm/IR/PatternMatch.h"
28+
#include "llvm/IR/User.h"
29+
#include "llvm/Support/Casting.h"
2330
#include "llvm/Support/Debug.h"
2431
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2532
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2633
#include <algorithm>
34+
//#include <cassert>
2735
#include <optional>
2836
using namespace llvm;
2937
using namespace llvm::PatternMatch;
@@ -3145,12 +3153,16 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31453153
return 0;
31463154
}
31473155

3148-
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
3149-
Type *Val,
3150-
unsigned Index,
3151-
bool HasRealUse) {
3156+
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3157+
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
3158+
unsigned Index, bool HasRealUse, Value *Scalar,
3159+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
31523160
assert(Val->isVectorTy() && "This must be a vector type");
31533161

3162+
const auto *I = (std::holds_alternative<const Instruction *>(InstOrOpcode)
3163+
? get<const Instruction *>(InstOrOpcode)
3164+
: nullptr);
3165+
31543166
if (Index != -1U) {
31553167
// Legalize the type.
31563168
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
@@ -3194,6 +3206,149 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
31943206
// compile-time considerations.
31953207
}
31963208

3209+
// In case of Neon, if there exists extractelement from lane != 0 such that
3210+
// 1. extractelement does not necessitate a move from vector_reg -> GPR.
3211+
// 2. extractelement result feeds into fmul.
3212+
// 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
3213+
// equivalent to 0.
3214+
// then the extractelement can be merged with fmul in the backend and it
3215+
// incurs no cost.
3216+
// e.g.
3217+
// define double @foo(<2 x double> %a) {
3218+
// %1 = extractelement <2 x double> %a, i32 0
3219+
// %2 = extractelement <2 x double> %a, i32 1
3220+
// %res = fmul double %1, %2
3221+
// ret double %res
3222+
// }
3223+
// %2 and %res can be merged in the backend to generate fmul d0, d0, v0.d[1]
3224+
auto ExtractCanFuseWithFmul = [&]() {
3225+
// We bail out if the extract is from lane 0.
3226+
if (Index == 0)
3227+
return false;
3228+
3229+
// Check if the scalar element type of the vector operand of ExtractElement
3230+
// instruction is one of the allowed types.
3231+
auto IsAllowedScalarTy = [&](const Type *T) {
3232+
return T->isFloatTy() || T->isDoubleTy() ||
3233+
(T->isHalfTy() && ST->hasFullFP16());
3234+
};
3235+
3236+
// Check if the extractelement user is scalar fmul.
3237+
auto IsUserFMulScalarTy = [](const Value *EEUser) {
3238+
// Check if the user is scalar fmul.
3239+
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3240+
return BO && BO->getOpcode() == BinaryOperator::FMul &&
3241+
!BO->getType()->isVectorTy();
3242+
};
3243+
3244+
// InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
3245+
// with fmul does not happen.
3246+
auto IsFMulUserFAddFSub = [](const Value *FMul) {
3247+
return any_of(FMul->users(), [](const User *U) {
3248+
const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
3249+
return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
3250+
BO->getOpcode() == BinaryOperator::FSub));
3251+
});
3252+
};
3253+
3254+
// Check if the type constraints on input vector type and result scalar type
3255+
// of extractelement instruction are satisfied.
3256+
auto TypeConstraintsOnEESatisfied =
3257+
[&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
3258+
return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
3259+
};
3260+
3261+
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
3262+
// certain scalar type and a certain vector register width.
3263+
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3264+
const unsigned &EltSz) {
3265+
auto RegWidth =
3266+
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3267+
.getFixedValue();
3268+
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
3269+
};
3270+
3271+
if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
3272+
if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
3273+
return false;
3274+
3275+
for (auto &RefT : ScalarUserAndIdx) {
3276+
Value *RefS = get<0>(RefT);
3277+
User *RefU = get<1>(RefT);
3278+
const int &RefL = get<2>(RefT);
3279+
3280+
// Analyze all the users which have the same scalar/index as Scalar/Index.
3281+
if (RefS != Scalar || RefL != Index)
3282+
continue;
3283+
3284+
// Check if the user of {Scalar, Index} pair is fmul user.
3285+
if (!IsUserFMulScalarTy(RefU) || IsFMulUserFAddFSub(RefU))
3286+
return false;
3287+
3288+
// For RefU, check if the other operand is extract from the same SLP
3289+
// tree. If not, we bail out since we can't analyze extracts from other
3290+
// SLP tree.
3291+
unsigned NumExtractEltsIntoUser = 0;
3292+
for (auto &CmpT : ScalarUserAndIdx) {
3293+
User *CmpU = get<1>(CmpT);
3294+
if (CmpT == RefT || CmpU != RefU)
3295+
continue;
3296+
Value *CmpS = get<0>(CmpT);
3297+
++NumExtractEltsIntoUser;
3298+
const int &CmpL = get<2>(CmpT);
3299+
if (!IsExtractLaneEquivalentToZero(CmpL, Val->getScalarSizeInBits()))
3300+
return false;
3301+
}
3302+
// We know this is an fmul user with just 2 operands, one being RefT. If we
3303+
// can't find CmpT, as the other operand, then bail out.
3304+
if (NumExtractEltsIntoUser != 1)
3305+
return false;
3306+
}
3307+
} else {
3308+
const auto *EE = cast<ExtractElementInst>(I);
3309+
3310+
const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3311+
if (!IdxOp)
3312+
return false;
3313+
3314+
if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
3315+
EE->getType()))
3316+
return false;
3317+
3318+
return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3319+
if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
3320+
return false;
3321+
3322+
// Check if the other operand of extractelement is also extractelement
3323+
// from lane equivalent to 0.
3324+
const auto *BO = cast<BinaryOperator>(U);
3325+
const auto *OtherEE = dyn_cast<ExtractElementInst>(
3326+
BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3327+
if (OtherEE) {
3328+
const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3329+
if (!IdxOp)
3330+
return false;
3331+
return IsExtractLaneEquivalentToZero(
3332+
cast<ConstantInt>(OtherEE->getIndexOperand())
3333+
->getValue()
3334+
.getZExtValue(),
3335+
OtherEE->getType()->getScalarSizeInBits());
3336+
}
3337+
return true;
3338+
});
3339+
}
3340+
return true;
3341+
};
3342+
3343+
if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
3344+
const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
3345+
if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
3346+
return 0;
3347+
} else if (I && I->getOpcode() == Instruction::ExtractElement &&
3348+
ExtractCanFuseWithFmul()) {
3349+
return 0;
3350+
}
3351+
31973352
// All other insert/extracts cost this much.
31983353
return ST->getVectorInsertExtractBaseCost();
31993354
}
@@ -3207,6 +3362,14 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32073362
return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
32083363
}
32093364

3365+
InstructionCost AArch64TTIImpl::getVectorInstrCost(
3366+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3367+
Value *Scalar,
3368+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
3369+
return getVectorInstrCostHelper(Opcode, Val, Index, false, Scalar,
3370+
ScalarUserAndIdx);
3371+
}
3372+
32103373
InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
32113374
Type *Val,
32123375
TTI::TargetCostKind CostKind,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "AArch64.h"
2020
#include "AArch64Subtarget.h"
2121
#include "AArch64TargetMachine.h"
22-
#include "llvm/ADT/ArrayRef.h"
22+
#include "llvm/ADT/SmallVector.h"
2323
#include "llvm/Analysis/TargetTransformInfo.h"
2424
#include "llvm/CodeGen/BasicTTIImpl.h"
2525
#include "llvm/IR/Function.h"
@@ -66,8 +66,11 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
6666
// 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
6767
// indicates whether the vector instruction is available in the input IR or
6868
// just imaginary in vectorizer passes.
69-
InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
70-
unsigned Index, bool HasRealUse);
69+
InstructionCost getVectorInstrCostHelper(
70+
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
71+
unsigned Index, bool HasRealUse, Value *Scalar = nullptr,
72+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx =
73+
SmallVector<std::tuple<Value *, User *, int>, 0>());
7174

7275
public:
7376
explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -185,6 +188,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
185188
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
186189
TTI::TargetCostKind CostKind,
187190
unsigned Index, Value *Op0, Value *Op1);
191+
192+
InstructionCost getVectorInstrCost(
193+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
194+
Value *Scalar,
195+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx);
196+
188197
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
189198
TTI::TargetCostKind CostKind,
190199
unsigned Index);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11633,6 +11633,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1163311633
std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
1163411634
DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
1163511635
SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
11636+
// Keep track of each {Scalar, User, Index} tuple.
11637+
// On AArch64, this helps in fusing a mov instruction, associated with
11638+
// extractelement, with fmul in the backend so that extractelement is free.
11639+
SmallVector<std::tuple<Value *, User *, int>, 4> ScalarUserAndIdx;
11640+
for (ExternalUser &EU : ExternalUses) {
11641+
ScalarUserAndIdx.emplace_back(EU.Scalar, EU.User, EU.Lane);
11642+
}
1163611643
for (ExternalUser &EU : ExternalUses) {
1163711644
// Uses by ephemeral values are free (because the ephemeral value will be
1163811645
// removed prior to code generation, and so the extraction will be
@@ -11739,8 +11746,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1173911746
ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
1174011747
VecTy, EU.Lane);
1174111748
} else {
11742-
ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
11743-
CostKind, EU.Lane);
11749+
ExtraCost =
11750+
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
11751+
EU.Lane, EU.Scalar, ScalarUserAndIdx);
1174411752
}
1174511753
// Leave the scalar instructions as is if they are cheaper than extracts.
1174611754
if (Entry->Idx != 0 || Entry->getOpcode() == Instruction::GetElementPtr ||

0 commit comments

Comments
 (0)