Skip to content

Commit 482785e

Browse files
committed
[TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost
- Following from llvm#168029. This is a step toward a unified interface for masked/gather-scatter/strided/expand-compress cost modeling. - Replace the ad-hoc parameter list with a single attributes object. API change: ``` - InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, Inst); + InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes, + CostKind); ``` Notes: - NFCI intended: callers populate MemIntrinsicCostAttributes with same information as before.
1 parent 961940e commit 482785e

17 files changed

+141
-92
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,23 @@ struct HardwareLoopInfo {
125125

126126
/// Information for memory intrinsic cost model.
127127
class MemIntrinsicCostAttributes {
128+
/// Optional context instruction, if one exists, e.g. the
129+
/// load/store to transform to the intrinsic.
130+
const Instruction *I = nullptr;
131+
132+
/// Address in memory.
133+
const Value *Ptr = nullptr;
134+
128135
/// Vector type of the data to be loaded or stored.
129136
Type *DataTy = nullptr;
130137

131138
/// ID of the memory intrinsic.
132139
Intrinsic::ID IID;
133140

141+
/// True when the memory access is predicated with a mask
142+
/// that is not a compile-time constant.
143+
bool VariableMask = true;
144+
134145
/// Address space of the pointer.
135146
unsigned AddressSpace = 0;
136147

@@ -143,8 +154,18 @@ class MemIntrinsicCostAttributes {
143154
: DataTy(DataTy), IID(Id), AddressSpace(AddressSpace),
144155
Alignment(Alignment) {}
145156

157+
LLVM_ABI MemIntrinsicCostAttributes(Intrinsic::ID Id, Type *DataTy,
158+
const Value *Ptr, bool VariableMask,
159+
Align Alignment,
160+
const Instruction *I = nullptr)
161+
: I(I), Ptr(Ptr), DataTy(DataTy), IID(Id), VariableMask(VariableMask),
162+
Alignment(Alignment) {}
163+
146164
Intrinsic::ID getID() const { return IID; }
165+
const Instruction *getInst() const { return I; }
166+
const Value *getPointer() const { return Ptr; }
147167
Type *getDataType() const { return DataTy; }
168+
bool getVariableMask() const { return VariableMask; }
148169
unsigned getAddressSpace() const { return AddressSpace; }
149170
Align getAlignment() const { return Alignment; }
150171
};
@@ -1595,9 +1616,8 @@ class TargetTransformInfo {
15951616
/// \p I - the optional original context instruction, if one exists, e.g. the
15961617
/// load/store to transform or the call to the gather/scatter intrinsic
15971618
LLVM_ABI InstructionCost getGatherScatterOpCost(
1598-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1599-
Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
1600-
const Instruction *I = nullptr) const;
1619+
const MemIntrinsicCostAttributes &MICA,
1620+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
16011621

16021622
/// \return The cost of Expand Load or Compress Store operation
16031623
/// \p Opcode - is a type of memory access Load or Store

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -848,10 +848,8 @@ class TargetTransformInfoImplBase {
848848
}
849849

850850
virtual InstructionCost
851-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
852-
bool VariableMask, Align Alignment,
853-
TTI::TargetCostKind CostKind,
854-
const Instruction *I = nullptr) const {
851+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
852+
TTI::TargetCostKind CostKind) const {
855853
return 1;
856854
}
857855

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,10 +1571,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
15711571
}
15721572

15731573
InstructionCost
1574-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
1575-
bool VariableMask, Align Alignment,
1576-
TTI::TargetCostKind CostKind,
1577-
const Instruction *I = nullptr) const override {
1574+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
1575+
TTI::TargetCostKind CostKind) const override {
1576+
unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
1577+
MICA.getID() == Intrinsic::vp_gather)
1578+
? Instruction::Load
1579+
: Instruction::Store;
1580+
Type *DataTy = MICA.getDataType();
1581+
bool VariableMask = MICA.getVariableMask();
1582+
Align Alignment = MICA.getAlignment();
15781583
return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
15791584
true, CostKind);
15801585
}
@@ -1598,8 +1603,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
15981603
// For a target without strided memory operations (or for an illegal
15991604
// operation type on one which does), assume we lower to a gather/scatter
16001605
// operation. (Which may in turn be scalarized.)
1601-
return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1602-
Alignment, CostKind, I);
1606+
unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
1607+
: Intrinsic::masked_scatter;
1608+
return thisT()->getGatherScatterOpCost(
1609+
{IID, DataTy, Ptr, VariableMask, Alignment, I}, CostKind);
16031610
}
16041611

16051612
InstructionCost getInterleavedMemoryOpCost(
@@ -1826,8 +1833,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
18261833
Alignment = VPI->getPointerAlignment().valueOrOne();
18271834
bool VarMask = isa<Constant>(ICA.getArgs()[2]);
18281835
return thisT()->getGatherScatterOpCost(
1829-
Instruction::Store, ICA.getArgTypes()[0], ICA.getArgs()[1], VarMask,
1830-
Alignment, CostKind, nullptr);
1836+
{ICA.getID(), ICA.getArgTypes()[0], ICA.getArgs()[1], VarMask,
1837+
Alignment, nullptr},
1838+
CostKind);
18311839
}
18321840
if (ICA.getID() == Intrinsic::vp_gather) {
18331841
if (ICA.isTypeBasedOnly()) {
@@ -1842,8 +1850,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
18421850
Alignment = VPI->getPointerAlignment().valueOrOne();
18431851
bool VarMask = isa<Constant>(ICA.getArgs()[1]);
18441852
return thisT()->getGatherScatterOpCost(
1845-
Instruction::Load, ICA.getReturnType(), ICA.getArgs()[0], VarMask,
1846-
Alignment, CostKind, nullptr);
1853+
{ICA.getID(), ICA.getReturnType(), ICA.getArgs()[0], VarMask,
1854+
Alignment, nullptr},
1855+
CostKind);
18471856
}
18481857

18491858
if (ICA.getID() == Intrinsic::vp_select ||
@@ -1948,16 +1957,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
19481957
const Value *Mask = Args[2];
19491958
bool VarMask = !isa<Constant>(Mask);
19501959
Align Alignment = I->getParamAlign(1).valueOrOne();
1951-
return thisT()->getGatherScatterOpCost(Instruction::Store,
1952-
ICA.getArgTypes()[0], Args[1],
1953-
VarMask, Alignment, CostKind, I);
1960+
return thisT()->getGatherScatterOpCost(
1961+
{IID, ICA.getArgTypes()[0], Args[1], VarMask, Alignment, I},
1962+
CostKind);
19541963
}
19551964
case Intrinsic::masked_gather: {
19561965
const Value *Mask = Args[1];
19571966
bool VarMask = !isa<Constant>(Mask);
19581967
Align Alignment = I->getParamAlign(0).valueOrOne();
1959-
return thisT()->getGatherScatterOpCost(Instruction::Load, RetTy, Args[0],
1960-
VarMask, Alignment, CostKind, I);
1968+
return thisT()->getGatherScatterOpCost(
1969+
{IID, RetTy, Args[0], VarMask, Alignment, I}, CostKind);
19611970
}
19621971
case Intrinsic::masked_compressstore: {
19631972
const Value *Data = Args[0];

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,10 +1191,9 @@ InstructionCost TargetTransformInfo::getMaskedMemoryOpCost(
11911191
}
11921192

11931193
InstructionCost TargetTransformInfo::getGatherScatterOpCost(
1194-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1195-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
1196-
InstructionCost Cost = TTIImpl->getGatherScatterOpCost(
1197-
Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
1194+
const MemIntrinsicCostAttributes &MICA,
1195+
TTI::TargetCostKind CostKind) const {
1196+
InstructionCost Cost = TTIImpl->getGatherScatterOpCost(MICA, CostKind);
11981197
assert((!Cost.isValid() || Cost >= 0) &&
11991198
"TTI should not produce negative costs!");
12001199
return Cost;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4777,12 +4777,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
47774777
}
47784778
}
47794779

4780-
InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
4781-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
4782-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
4780+
InstructionCost
4781+
AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
4782+
TTI::TargetCostKind CostKind) const {
4783+
4784+
unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
4785+
MICA.getID() == Intrinsic::vp_gather)
4786+
? Instruction::Load
4787+
: Instruction::Store;
4788+
4789+
Type *DataTy = MICA.getDataType();
4790+
Align Alignment = MICA.getAlignment();
4791+
const Instruction *I = MICA.getInst();
4792+
47834793
if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
4784-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
4785-
Alignment, CostKind, I);
4794+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
47864795
auto *VT = cast<VectorType>(DataTy);
47874796
auto LT = getTypeLegalizationCost(DataTy);
47884797
if (!LT.first.isValid())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
192192
TTI::TargetCostKind CostKind) const override;
193193

194194
InstructionCost
195-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
196-
bool VariableMask, Align Alignment,
197-
TTI::TargetCostKind CostKind,
198-
const Instruction *I = nullptr) const override;
195+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
196+
TTI::TargetCostKind CostKind) const override;
199197

200198
bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
201199
Type *Src) const;

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,13 +1693,19 @@ InstructionCost ARMTTIImpl::getInterleavedMemoryOpCost(
16931693
UseMaskForCond, UseMaskForGaps);
16941694
}
16951695

1696-
InstructionCost ARMTTIImpl::getGatherScatterOpCost(
1697-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1698-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
1696+
InstructionCost
1697+
ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
1698+
TTI::TargetCostKind CostKind) const {
1699+
1700+
Type *DataTy = MICA.getDataType();
1701+
const Value *Ptr = MICA.getPointer();
1702+
bool VariableMask = MICA.getVariableMask();
1703+
Align Alignment = MICA.getAlignment();
1704+
const Instruction *I = MICA.getInst();
1705+
16991706
using namespace PatternMatch;
17001707
if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
1701-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1702-
Alignment, CostKind, I);
1708+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
17031709

17041710
assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
17051711
auto *VTy = cast<FixedVectorType>(DataTy);

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
284284
bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
285285

286286
InstructionCost
287-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
288-
bool VariableMask, Align Alignment,
289-
TTI::TargetCostKind CostKind,
290-
const Instruction *I = nullptr) const override;
287+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
288+
TTI::TargetCostKind CostKind) const override;
291289

292290
InstructionCost
293291
getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,10 @@ HexagonTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
238238
return 1;
239239
}
240240

241-
InstructionCost HexagonTTIImpl::getGatherScatterOpCost(
242-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
243-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
244-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
245-
Alignment, CostKind, I);
241+
InstructionCost
242+
HexagonTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
243+
TTI::TargetCostKind CostKind) const {
244+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
246245
}
247246

248247
InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost(

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
127127
ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
128128
VectorType *SubTp, ArrayRef<const Value *> Args = {},
129129
const Instruction *CxtI = nullptr) const override;
130-
InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
131-
const Value *Ptr, bool VariableMask,
132-
Align Alignment,
133-
TTI::TargetCostKind CostKind,
134-
const Instruction *I) const override;
130+
InstructionCost
131+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
132+
TTI::TargetCostKind CostKind) const override;
135133
InstructionCost getInterleavedMemoryOpCost(
136134
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
137135
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,

0 commit comments

Comments
 (0)