Skip to content

Commit 72e6397

Browse files
committed
[CostModel][AArch64] Make extractelement, with fmul user, free whenever possible
In case of Neon, if there exists extractelement from lane != 0 such that 1. extractelement does not necessitate a move from vector_reg -> GPR. 2. extractelement result feeds into fmul. 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane equivalent to 0. then the extractelement can be merged with fmul in the backend and it incurs no cost. e.g. define double @foo(<2 x double> %a) { %1 = extractelement <2 x double> %a, i32 0 %2 = extractelement <2 x double> %a, i32 1 %res = fmul double %1, %2 ret double %res } %2 and %res can be merged in the backend to generate: fmul d0, d0, v0.d[1] The change was tested with SPEC FP(C/C++) on Neoverse-v2. Compile time impact: None Performance impact: Observing 1.3-1.7% uplift on lbm benchmark with -flto depending upon the config.
1 parent 924a64a commit 72e6397

File tree

9 files changed

+275
-74
lines changed

9 files changed

+275
-74
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
2323

2424
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/ArrayRef.h"
2526
#include "llvm/IR/FMF.h"
2627
#include "llvm/IR/InstrTypes.h"
2728
#include "llvm/IR/PassManager.h"
@@ -1392,6 +1393,16 @@ class TargetTransformInfo {
13921393
unsigned Index = -1, Value *Op0 = nullptr,
13931394
Value *Op1 = nullptr) const;
13941395

1396+
/// \return The expected cost of vector Insert and Extract.
1397+
/// Use -1 to indicate that there is no information on the index value.
1398+
/// This is used when the instruction is not available; a typical use
1399+
/// case is to provision the cost of vectorization/scalarization in
1400+
/// vectorizer passes.
1401+
InstructionCost getVectorInstrCost(
1402+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1403+
Value *Scalar,
1404+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const;
1405+
13951406
/// \return The expected cost of vector Insert and Extract.
13961407
/// This is used when instruction is available, and implementation
13971408
/// asserts 'I' is not nullptr.
@@ -2062,6 +2073,12 @@ class TargetTransformInfo::Concept {
20622073
TTI::TargetCostKind CostKind,
20632074
unsigned Index, Value *Op0,
20642075
Value *Op1) = 0;
2076+
2077+
virtual InstructionCost getVectorInstrCost(
2078+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2079+
Value *Scalar,
2080+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) = 0;
2081+
20652082
virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
20662083
TTI::TargetCostKind CostKind,
20672084
unsigned Index) = 0;
@@ -2726,6 +2743,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27262743
Value *Op1) override {
27272744
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
27282745
}
2746+
InstructionCost getVectorInstrCost(
2747+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
2748+
Value *Scalar,
2749+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) override {
2750+
return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Scalar,
2751+
ScalarUserAndIdx);
2752+
}
27292753
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
27302754
TTI::TargetCostKind CostKind,
27312755
unsigned Index) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,13 @@ class TargetTransformInfoImplBase {
683683
return 1;
684684
}
685685

686+
InstructionCost getVectorInstrCost(
687+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
688+
Value *Scalar,
689+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
690+
return 1;
691+
}
692+
686693
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
687694
TTI::TargetCostKind CostKind,
688695
unsigned Index) const {

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#define LLVM_CODEGEN_BASICTTIIMPL_H
1818

1919
#include "llvm/ADT/APInt.h"
20-
#include "llvm/ADT/ArrayRef.h"
2120
#include "llvm/ADT/BitVector.h"
2221
#include "llvm/ADT/SmallPtrSet.h"
2322
#include "llvm/ADT/SmallVector.h"
@@ -1277,12 +1276,20 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
12771276
return 1;
12781277
}
12791278

1280-
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
1281-
TTI::TargetCostKind CostKind,
1282-
unsigned Index, Value *Op0, Value *Op1) {
1279+
virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
1280+
TTI::TargetCostKind CostKind,
1281+
unsigned Index, Value *Op0,
1282+
Value *Op1) {
12831283
return getRegUsageForType(Val->getScalarType());
12841284
}
12851285

1286+
InstructionCost getVectorInstrCost(
1287+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1288+
Value *Scalar,
1289+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
1290+
return getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr, nullptr);
1291+
}
1292+
12861293
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
12871294
TTI::TargetCostKind CostKind,
12881295
unsigned Index) {

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,19 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(
10371037
return Cost;
10381038
}
10391039

1040+
InstructionCost TargetTransformInfo::getVectorInstrCost(
1041+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
1042+
Value *Scalar,
1043+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
1044+
// FIXME: Assert that Opcode is either InsertElement or ExtractElement.
1045+
// This is mentioned in the interface description and respected by all
1046+
// callers, but never asserted upon.
1047+
InstructionCost Cost = TTIImpl->getVectorInstrCost(
1048+
Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx);
1049+
assert(Cost >= 0 && "TTI should not produce negative costs!");
1050+
return Cost;
1051+
}
1052+
10401053
InstructionCost
10411054
TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
10421055
TTI::TargetCostKind CostKind,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 162 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,29 @@
1010
#include "AArch64ExpandImm.h"
1111
#include "AArch64PerfectShuffle.h"
1212
#include "MCTargetDesc/AArch64AddressingModes.h"
13+
#include "llvm/ADT/DenseMap.h"
14+
#include "llvm/ADT/STLExtras.h"
1315
#include "llvm/Analysis/IVDescriptors.h"
1416
#include "llvm/Analysis/LoopInfo.h"
1517
#include "llvm/Analysis/TargetTransformInfo.h"
1618
#include "llvm/CodeGen/BasicTTIImpl.h"
1719
#include "llvm/CodeGen/CostTable.h"
1820
#include "llvm/CodeGen/TargetLowering.h"
21+
#include "llvm/IR/DerivedTypes.h"
22+
#include "llvm/IR/InstrTypes.h"
23+
#include "llvm/IR/Instruction.h"
24+
#include "llvm/IR/Instructions.h"
1925
#include "llvm/IR/IntrinsicInst.h"
2026
#include "llvm/IR/Intrinsics.h"
2127
#include "llvm/IR/IntrinsicsAArch64.h"
2228
#include "llvm/IR/PatternMatch.h"
29+
#include "llvm/IR/User.h"
30+
#include "llvm/Support/Casting.h"
2331
#include "llvm/Support/Debug.h"
2432
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2533
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2634
#include <algorithm>
35+
//#include <cassert>
2736
#include <optional>
2837
using namespace llvm;
2938
using namespace llvm::PatternMatch;
@@ -3145,12 +3154,16 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
31453154
return 0;
31463155
}
31473156

3148-
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
3149-
Type *Val,
3150-
unsigned Index,
3151-
bool HasRealUse) {
3157+
InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3158+
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
3159+
unsigned Index, bool HasRealUse, Value *Scalar,
3160+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
31523161
assert(Val->isVectorTy() && "This must be a vector type");
31533162

3163+
const auto *I = (std::holds_alternative<const Instruction *>(InstOrOpcode)
3164+
? get<const Instruction *>(InstOrOpcode)
3165+
: nullptr);
3166+
31543167
if (Index != -1U) {
31553168
// Legalize the type.
31563169
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
@@ -3194,6 +3207,143 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
31943207
// compile-time considerations.
31953208
}
31963209

3210+
// In case of Neon, if there exists extractelement from lane != 0 such that
3211+
// 1. extractelement does not necessitate a move from vector_reg -> GPR.
3212+
// 2. extractelement result feeds into fmul.
3213+
// 3. Other operand of fmul is an extractelement from lane 0 or lane
3214+
// equivalent to 0.
3215+
// then the extractelement can be merged with fmul in the backend and it
3216+
// incurs no cost.
3217+
// e.g.
3218+
// define double @foo(<2 x double> %a) {
3219+
// %1 = extractelement <2 x double> %a, i32 0
3220+
// %2 = extractelement <2 x double> %a, i32 1
3221+
// %res = fmul double %1, %2
3222+
// ret double %res
3223+
// }
3224+
// %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3225+
auto ExtractCanFuseWithFmul = [&]() {
3226+
// We bail out if the extract is from lane 0.
3227+
if (Index == 0)
3228+
return false;
3229+
3230+
// Check if the scalar element type of the vector operand of ExtractElement
3231+
// instruction is one of the allowed types.
3232+
auto IsAllowedScalarTy = [&](const Type *T) {
3233+
return T->isFloatTy() || T->isDoubleTy() ||
3234+
(T->isHalfTy() && ST->hasFullFP16());
3235+
};
3236+
3237+
// Check if the extractelement user is scalar fmul.
3238+
auto IsUserFMulScalarTy = [](const Value *EEUser) {
3239+
// Check if the user is scalar fmul.
3240+
const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
3241+
return BO && BO->getOpcode() == BinaryOperator::FMul &&
3242+
!BO->getType()->isVectorTy();
3243+
};
3244+
3245+
// InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
3246+
// with fmul does not happen.
3247+
auto IsFMulUserFAddFSub = [](const Value *FMul) {
3248+
return any_of(FMul->users(), [](const User *U) {
3249+
const auto *BO = dyn_cast_if_present<BinaryOperator>(U);
3250+
return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
3251+
BO->getOpcode() == BinaryOperator::FSub));
3252+
});
3253+
};
3254+
3255+
// Check if the type constraints on input vector type and result scalar type
3256+
// of extractelement instruction are satisfied.
3257+
auto TypeConstraintsOnEESatisfied =
3258+
[&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
3259+
return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
3260+
};
3261+
3262+
// Check if the extract index is from lane 0 or lane equivalent to 0 for a
3263+
// certain scalar type and a certain vector register width.
3264+
auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
3265+
const unsigned &EltSz) {
3266+
auto RegWidth =
3267+
getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3268+
.getFixedValue();
3269+
return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
3270+
};
3271+
3272+
if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
3273+
if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
3274+
return false;
3275+
3276+
DenseMap<User *, unsigned> UserToExtractIdx;
3277+
for (auto *U : Scalar->users()) {
3278+
if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
3279+
return false;
3280+
// Recording entry for the user is important. Index value is not
3281+
// important.
3282+
UserToExtractIdx[U];
3283+
}
3284+
for (auto &[S, U, L] : ScalarUserAndIdx) {
3285+
for (auto *U : S->users()) {
3286+
if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
3287+
auto *FMul = cast<BinaryOperator>(U);
3288+
auto *Op0 = FMul->getOperand(0);
3289+
auto *Op1 = FMul->getOperand(1);
3290+
if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
3291+
UserToExtractIdx[U] = L;
3292+
break;
3293+
}
3294+
}
3295+
}
3296+
}
3297+
for (auto &[U, L] : UserToExtractIdx) {
3298+
if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
3299+
!IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
3300+
return false;
3301+
}
3302+
} else {
3303+
const auto *EE = cast<ExtractElementInst>(I);
3304+
3305+
const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3306+
if (!IdxOp)
3307+
return false;
3308+
3309+
if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
3310+
EE->getType()))
3311+
return false;
3312+
3313+
return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3314+
if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
3315+
return false;
3316+
3317+
// Check if the other operand of extractelement is also extractelement
3318+
// from lane equivalent to 0.
3319+
const auto *BO = cast<BinaryOperator>(U);
3320+
const auto *OtherEE = dyn_cast<ExtractElementInst>(
3321+
BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3322+
if (OtherEE) {
3323+
const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3324+
if (!IdxOp)
3325+
return false;
3326+
return IsExtractLaneEquivalentToZero(
3327+
cast<ConstantInt>(OtherEE->getIndexOperand())
3328+
->getValue()
3329+
.getZExtValue(),
3330+
OtherEE->getType()->getScalarSizeInBits());
3331+
}
3332+
return true;
3333+
});
3334+
}
3335+
return true;
3336+
};
3337+
3338+
if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
3339+
const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
3340+
if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
3341+
return 0;
3342+
} else if (I && I->getOpcode() == Instruction::ExtractElement &&
3343+
ExtractCanFuseWithFmul()) {
3344+
return 0;
3345+
}
3346+
31973347
// All other insert/extracts cost this much.
31983348
return ST->getVectorInsertExtractBaseCost();
31993349
}
@@ -3207,6 +3357,14 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
32073357
return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
32083358
}
32093359

3360+
InstructionCost AArch64TTIImpl::getVectorInstrCost(
3361+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3362+
Value *Scalar,
3363+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) {
3364+
return getVectorInstrCostHelper(Opcode, Val, Index, false, Scalar,
3365+
ScalarUserAndIdx);
3366+
}
3367+
32103368
InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
32113369
Type *Val,
32123370
TTI::TargetCostKind CostKind,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "AArch64.h"
2020
#include "AArch64Subtarget.h"
2121
#include "AArch64TargetMachine.h"
22-
#include "llvm/ADT/ArrayRef.h"
2322
#include "llvm/Analysis/TargetTransformInfo.h"
2423
#include "llvm/CodeGen/BasicTTIImpl.h"
2524
#include "llvm/IR/Function.h"
@@ -66,8 +65,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
6665
// 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
6766
// indicates whether the vector instruction is available in the input IR or
6867
// just imaginary in vectorizer passes.
69-
InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
70-
unsigned Index, bool HasRealUse);
68+
InstructionCost getVectorInstrCostHelper(
69+
std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
70+
unsigned Index, bool HasRealUse, Value *Scalar = nullptr,
71+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx = {});
7172

7273
public:
7374
explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -185,6 +186,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
185186
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
186187
TTI::TargetCostKind CostKind,
187188
unsigned Index, Value *Op0, Value *Op1);
189+
190+
InstructionCost getVectorInstrCost(
191+
unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
192+
Value *Scalar,
193+
ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx);
194+
188195
InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
189196
TTI::TargetCostKind CostKind,
190197
unsigned Index);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11633,6 +11633,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1163311633
std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
1163411634
DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
1163511635
SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
11636+
// Keep track {Scalar, Index, User} tuple.
11637+
// On AArch64, this helps in fusing a mov instruction, associated with
11638+
// extractelement, with fmul in the backend so that extractelement is free.
11639+
SmallVector<std::tuple<Value *, User *, int>, 4> ScalarUserAndIdx;
11640+
for (ExternalUser &EU : ExternalUses) {
11641+
ScalarUserAndIdx.emplace_back(EU.Scalar, EU.User, EU.Lane);
11642+
}
1163611643
for (ExternalUser &EU : ExternalUses) {
1163711644
// Uses by ephemeral values are free (because the ephemeral value will be
1163811645
// removed prior to code generation, and so the extraction will be
@@ -11739,8 +11746,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1173911746
ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
1174011747
VecTy, EU.Lane);
1174111748
} else {
11742-
ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
11743-
CostKind, EU.Lane);
11749+
ExtraCost =
11750+
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
11751+
EU.Lane, EU.Scalar, ScalarUserAndIdx);
1174411752
}
1174511753
// Leave the scalar instructions as is if they are cheaper than extracts.
1174611754
if (Entry->Idx != 0 || Entry->getOpcode() == Instruction::GetElementPtr ||

0 commit comments

Comments
 (0)