Skip to content

Commit 0ad6be1

Browse files
authored
[SLPVectorizer, TargetTransformInfo, SystemZ] Improve SLP getGatherCost(). (#112491)
As vector element loads are free on SystemZ, this patch improves the cost computation in getGatherCost() to reflect this. getScalarizationOverhead() gets an optional parameter which can hold the actual Values so that they in turn can be passed (by BasicTTIImpl) to getVectorInstrCost(). SystemZTTIImpl::getVectorInstrCost() will now recognize a LoadInst and typically return a 0 cost for it, with some exceptions.
1 parent cbf495f commit 0ad6be1

14 files changed

+209
-88
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -909,11 +909,13 @@ class TargetTransformInfo {
909909

910910
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
911911
/// are set if the demanded result elements need to be inserted and/or
912-
/// extracted from vectors.
912+
/// extracted from vectors. The involved values may be passed in VL if
913+
/// Insert is true.
913914
InstructionCost getScalarizationOverhead(VectorType *Ty,
914915
const APInt &DemandedElts,
915916
bool Insert, bool Extract,
916-
TTI::TargetCostKind CostKind) const;
917+
TTI::TargetCostKind CostKind,
918+
ArrayRef<Value *> VL = {}) const;
917919

918920
/// Estimate the overhead of scalarizing an instructions unique
919921
/// non-constant operands. The (potentially vector) types to use for each of
@@ -2001,10 +2003,10 @@ class TargetTransformInfo::Concept {
20012003
unsigned ScalarOpdIdx) = 0;
20022004
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
20032005
int ScalarOpdIdx) = 0;
2004-
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
2005-
const APInt &DemandedElts,
2006-
bool Insert, bool Extract,
2007-
TargetCostKind CostKind) = 0;
2006+
virtual InstructionCost
2007+
getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
2008+
bool Insert, bool Extract, TargetCostKind CostKind,
2009+
ArrayRef<Value *> VL = {}) = 0;
20082010
virtual InstructionCost
20092011
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
20102012
ArrayRef<Type *> Tys,
@@ -2585,9 +2587,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25852587
InstructionCost getScalarizationOverhead(VectorType *Ty,
25862588
const APInt &DemandedElts,
25872589
bool Insert, bool Extract,
2588-
TargetCostKind CostKind) override {
2590+
TargetCostKind CostKind,
2591+
ArrayRef<Value *> VL = {}) override {
25892592
return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
2590-
CostKind);
2593+
CostKind, VL);
25912594
}
25922595
InstructionCost
25932596
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ class TargetTransformInfoImplBase {
404404
InstructionCost getScalarizationOverhead(VectorType *Ty,
405405
const APInt &DemandedElts,
406406
bool Insert, bool Extract,
407-
TTI::TargetCostKind CostKind) const {
407+
TTI::TargetCostKind CostKind,
408+
ArrayRef<Value *> VL = {}) const {
408409
return 0;
409410
}
410411

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -780,24 +780,28 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
780780
InstructionCost getScalarizationOverhead(VectorType *InTy,
781781
const APInt &DemandedElts,
782782
bool Insert, bool Extract,
783-
TTI::TargetCostKind CostKind) {
783+
TTI::TargetCostKind CostKind,
784+
ArrayRef<Value *> VL = {}) {
784785
/// FIXME: a bitfield is not a reasonable abstraction for talking about
785786
/// which elements are needed from a scalable vector
786787
if (isa<ScalableVectorType>(InTy))
787788
return InstructionCost::getInvalid();
788789
auto *Ty = cast<FixedVectorType>(InTy);
789790

790791
assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
792+
(VL.empty() || VL.size() == Ty->getNumElements()) &&
791793
"Vector size mismatch");
792794

793795
InstructionCost Cost = 0;
794796

795797
for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
796798
if (!DemandedElts[i])
797799
continue;
798-
if (Insert)
800+
if (Insert) {
801+
Value *InsertedVal = VL.empty() ? nullptr : VL[i];
799802
Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
800-
CostKind, i, nullptr, nullptr);
803+
CostKind, i, nullptr, InsertedVal);
804+
}
801805
if (Extract)
802806
Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty,
803807
CostKind, i, nullptr, nullptr);

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,9 @@ bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
622622

623623
InstructionCost TargetTransformInfo::getScalarizationOverhead(
624624
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
625-
TTI::TargetCostKind CostKind) const {
625+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) const {
626626
return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
627-
CostKind);
627+
CostKind, VL);
628628
}
629629

630630
InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3363,7 +3363,7 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
33633363

33643364
InstructionCost AArch64TTIImpl::getScalarizationOverhead(
33653365
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
3366-
TTI::TargetCostKind CostKind) {
3366+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
33673367
if (isa<ScalableVectorType>(Ty))
33683368
return InstructionCost::getInvalid();
33693369
if (Ty->getElementType()->isFloatingPointTy())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
423423
InstructionCost getScalarizationOverhead(VectorType *Ty,
424424
const APInt &DemandedElts,
425425
bool Insert, bool Extract,
426-
TTI::TargetCostKind CostKind);
426+
TTI::TargetCostKind CostKind,
427+
ArrayRef<Value *> VL = {});
427428

428429
/// Return the cost of the scaling factor used in the addressing
429430
/// mode represented by AM for this target, for a load/store

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ static unsigned isM1OrSmaller(MVT VT) {
669669

670670
InstructionCost RISCVTTIImpl::getScalarizationOverhead(
671671
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
672-
TTI::TargetCostKind CostKind) {
672+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
673673
if (isa<ScalableVectorType>(Ty))
674674
return InstructionCost::getInvalid();
675675

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
149149
InstructionCost getScalarizationOverhead(VectorType *Ty,
150150
const APInt &DemandedElts,
151151
bool Insert, bool Extract,
152-
TTI::TargetCostKind CostKind);
152+
TTI::TargetCostKind CostKind,
153+
ArrayRef<Value *> VL = {});
153154

154155
InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
155156
TTI::TargetCostKind CostKind);

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,42 @@ bool SystemZTTIImpl::hasDivRemOp(Type *DataType, bool IsSigned) {
468468
return (VT.isScalarInteger() && TLI->isTypeLegal(VT));
469469
}
470470

471+
static bool isFreeEltLoad(Value *Op) {
472+
if (isa<LoadInst>(Op) && Op->hasOneUse()) {
473+
const Instruction *UserI = cast<Instruction>(*Op->user_begin());
474+
return !isa<StoreInst>(UserI); // Prefer MVC
475+
}
476+
return false;
477+
}
478+
479+
InstructionCost SystemZTTIImpl::getScalarizationOverhead(
480+
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
481+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
482+
unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
483+
InstructionCost Cost = 0;
484+
485+
if (Insert && Ty->isIntOrIntVectorTy(64)) {
486+
// VLVGP will insert two GPRs with one instruction, while VLE will load
487+
// an element directly with no extra cost
488+
assert((VL.empty() || VL.size() == NumElts) &&
489+
"Type does not match the number of values.");
490+
InstructionCost CurrVectorCost = 0;
491+
for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
492+
if (DemandedElts[Idx] && !(VL.size() && isFreeEltLoad(VL[Idx])))
493+
++CurrVectorCost;
494+
if (Idx % 2 == 1) {
495+
Cost += std::min(InstructionCost(1), CurrVectorCost);
496+
CurrVectorCost = 0;
497+
}
498+
}
499+
Insert = false;
500+
}
501+
502+
Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
503+
CostKind, VL);
504+
return Cost;
505+
}
506+
471507
// Return the bit size for the scalar type or vector element
472508
// type. getScalarSizeInBits() returns 0 for a pointer type.
473509
static unsigned getScalarSizeInBits(Type *Ty) {
@@ -609,7 +645,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
609645
if (DivRemConst) {
610646
SmallVector<Type *> Tys(Args.size(), Ty);
611647
return VF * DivMulSeqCost +
612-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
648+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
613649
}
614650
if ((SignedDivRem || UnsignedDivRem) && VF > 4)
615651
// Temporary hack: disable high vectorization factors with integer
@@ -636,7 +672,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
636672
SmallVector<Type *> Tys(Args.size(), Ty);
637673
InstructionCost Cost =
638674
(VF * ScalarCost) +
639-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
675+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
640676
// FIXME: VF 2 for these FP operations are currently just as
641677
// expensive as for VF 4.
642678
if (VF == 2)
@@ -654,8 +690,9 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
654690
// There is no native support for FRem.
655691
if (Opcode == Instruction::FRem) {
656692
SmallVector<Type *> Tys(Args.size(), Ty);
657-
InstructionCost Cost = (VF * LIBCALL_COST) +
658-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
693+
InstructionCost Cost =
694+
(VF * LIBCALL_COST) +
695+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
659696
// FIXME: VF 2 for float is currently just as expensive as for VF 4.
660697
if (VF == 2 && ScalarBits == 32)
661698
Cost *= 2;
@@ -975,10 +1012,10 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
9751012
(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
9761013
NeedsExtracts = false;
9771014

978-
TotCost += getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
979-
NeedsExtracts, CostKind);
980-
TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts,
981-
/*Extract*/ false, CostKind);
1015+
TotCost += BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1016+
NeedsExtracts, CostKind);
1017+
TotCost += BaseT::getScalarizationOverhead(DstVecTy, NeedsInserts,
1018+
/*Extract*/ false, CostKind);
9821019

9831020
// FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
9841021
if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -990,8 +1027,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
9901027
if (Opcode == Instruction::FPTrunc) {
9911028
if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements.
9921029
return VF /*ldxbr/lexbr*/ +
993-
getScalarizationOverhead(DstVecTy, /*Insert*/ true,
994-
/*Extract*/ false, CostKind);
1030+
BaseT::getScalarizationOverhead(DstVecTy, /*Insert*/ true,
1031+
/*Extract*/ false, CostKind);
9951032
else // double -> float
9961033
return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
9971034
}
@@ -1004,8 +1041,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
10041041
return VF * 2;
10051042
}
10061043
// -> fp128. VF * lxdb/lxeb + extraction of elements.
1007-
return VF + getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1008-
/*Extract*/ true, CostKind);
1044+
return VF + BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1045+
/*Extract*/ true, CostKind);
10091046
}
10101047
}
10111048

@@ -1114,10 +1151,17 @@ InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
11141151
TTI::TargetCostKind CostKind,
11151152
unsigned Index, Value *Op0,
11161153
Value *Op1) {
1117-
// vlvgp will insert two grs into a vector register, so only count half the
1118-
// number of instructions.
1119-
if (Opcode == Instruction::InsertElement && Val->isIntOrIntVectorTy(64))
1120-
return ((Index % 2 == 0) ? 1 : 0);
1154+
if (Opcode == Instruction::InsertElement) {
1155+
// Vector Element Load.
1156+
if (Op1 != nullptr && isFreeEltLoad(Op1))
1157+
return 0;
1158+
1159+
// vlvgp will insert two grs into a vector register, so count half the
1160+
// number of instructions as an estimate when we don't have the full
1161+
// picture (as in getScalarizationOverhead()).
1162+
if (Val->isIntOrIntVectorTy(64))
1163+
return ((Index % 2 == 0) ? 1 : 0);
1164+
}
11211165

11221166
if (Opcode == Instruction::ExtractElement) {
11231167
int Cost = ((getScalarSizeInBits(Val) == 1) ? 2 /*+test-under-mask*/ : 1);

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
8181
bool hasDivRemOp(Type *DataType, bool IsSigned);
8282
bool prefersVectorizedAddressing() { return false; }
8383
bool LSRWithInstrQueries() { return true; }
84+
InstructionCost getScalarizationOverhead(VectorType *Ty,
85+
const APInt &DemandedElts,
86+
bool Insert, bool Extract,
87+
TTI::TargetCostKind CostKind,
88+
ArrayRef<Value *> VL = {});
8489
bool supportsEfficientVectorElementLoadStore() { return true; }
8590
bool enableInterleavedAccessVectorization() { return true; }
8691

0 commit comments

Comments
 (0)