Skip to content

Commit 879f460

Browse files
committed
[TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost
- Following from #168029. This is a step toward a unified interface for masked/gather-scatter/strided/expand-compress cost modeling. - Replace the ad-hoc parameter list with a single attributes object. API change: ``` - InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, Inst); + InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes, + CostKind); ``` Notes: - NFCI intended: callers populate MemIntrinsicCostAttributes with same information as before.
1 parent b9bdec3 commit 879f460

12 files changed

+83
-69
lines changed

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -852,10 +852,8 @@ class TargetTransformInfoImplBase {
852852
}
853853

854854
virtual InstructionCost
855-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
856-
bool VariableMask, Align Alignment,
857-
TTI::TargetCostKind CostKind,
858-
const Instruction *I = nullptr) const {
855+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
856+
TTI::TargetCostKind CostKind) const {
859857
return 1;
860858
}
861859

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,10 +1571,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
15711571
}
15721572

15731573
InstructionCost
1574-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
1575-
bool VariableMask, Align Alignment,
1576-
TTI::TargetCostKind CostKind,
1577-
const Instruction *I = nullptr) const override {
1574+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
1575+
TTI::TargetCostKind CostKind) const override {
1576+
unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
1577+
MICA.getID() == Intrinsic::vp_gather)
1578+
? Instruction::Load
1579+
: Instruction::Store;
1580+
Type *DataTy = MICA.getDataType();
1581+
bool VariableMask = MICA.getVariableMask();
1582+
Align Alignment = MICA.getAlignment();
15781583
return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
15791584
true, CostKind);
15801585
}
@@ -1602,8 +1607,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
16021607
// For a target without strided memory operations (or for an illegal
16031608
// operation type on one which does), assume we lower to a gather/scatter
16041609
// operation. (Which may in turn be scalarized.)
1605-
return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1606-
Alignment, CostKind, I);
1610+
unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
1611+
: Intrinsic::masked_scatter;
1612+
return thisT()->getGatherScatterOpCost(
1613+
MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment,
1614+
I),
1615+
CostKind);
16071616
}
16081617

16091618
InstructionCost getInterleavedMemoryOpCost(
@@ -3062,14 +3071,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
30623071
case Intrinsic::masked_scatter:
30633072
case Intrinsic::masked_gather:
30643073
case Intrinsic::vp_scatter:
3065-
case Intrinsic::vp_gather: {
3066-
unsigned Opcode =
3067-
(Id == Intrinsic::masked_gather || Id == Intrinsic::vp_gather)
3068-
? Instruction::Load
3069-
: Instruction::Store;
3070-
return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
3071-
Alignment, CostKind, I);
3072-
}
3074+
case Intrinsic::vp_gather:
3075+
return thisT()->getGatherScatterOpCost(MICA, CostKind);
30733076
case Intrinsic::masked_load:
30743077
case Intrinsic::masked_store:
30753078
return thisT()->getMaskedMemoryOpCost(MICA, CostKind);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4782,12 +4782,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
47824782
}
47834783
}
47844784

4785-
InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
4786-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
4787-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
4785+
InstructionCost
4786+
AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
4787+
TTI::TargetCostKind CostKind) const {
4788+
4789+
unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
4790+
MICA.getID() == Intrinsic::vp_gather)
4791+
? Instruction::Load
4792+
: Instruction::Store;
4793+
4794+
Type *DataTy = MICA.getDataType();
4795+
Align Alignment = MICA.getAlignment();
4796+
const Instruction *I = MICA.getInst();
4797+
47884798
if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
4789-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
4790-
Alignment, CostKind, I);
4799+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
47914800
auto *VT = cast<VectorType>(DataTy);
47924801
auto LT = getTypeLegalizationCost(DataTy);
47934802
if (!LT.first.isValid())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
192192
TTI::TargetCostKind CostKind) const override;
193193

194194
InstructionCost
195-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
196-
bool VariableMask, Align Alignment,
197-
TTI::TargetCostKind CostKind,
198-
const Instruction *I = nullptr) const override;
195+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
196+
TTI::TargetCostKind CostKind) const override;
199197

200198
bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
201199
Type *Src) const;

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,13 +1694,19 @@ InstructionCost ARMTTIImpl::getInterleavedMemoryOpCost(
16941694
UseMaskForCond, UseMaskForGaps);
16951695
}
16961696

1697-
InstructionCost ARMTTIImpl::getGatherScatterOpCost(
1698-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1699-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
1697+
InstructionCost
1698+
ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
1699+
TTI::TargetCostKind CostKind) const {
1700+
1701+
Type *DataTy = MICA.getDataType();
1702+
const Value *Ptr = MICA.getPointer();
1703+
bool VariableMask = MICA.getVariableMask();
1704+
Align Alignment = MICA.getAlignment();
1705+
const Instruction *I = MICA.getInst();
1706+
17001707
using namespace PatternMatch;
17011708
if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
1702-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1703-
Alignment, CostKind, I);
1709+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
17041710

17051711
assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
17061712
auto *VTy = cast<FixedVectorType>(DataTy);

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
288288
bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
289289

290290
InstructionCost
291-
getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
292-
bool VariableMask, Align Alignment,
293-
TTI::TargetCostKind CostKind,
294-
const Instruction *I = nullptr) const override;
291+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
292+
TTI::TargetCostKind CostKind) const override;
295293

296294
InstructionCost
297295
getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,10 @@ HexagonTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
238238
return 1;
239239
}
240240

241-
InstructionCost HexagonTTIImpl::getGatherScatterOpCost(
242-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
243-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
244-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
245-
Alignment, CostKind, I);
241+
InstructionCost
242+
HexagonTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
243+
TTI::TargetCostKind CostKind) const {
244+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
246245
}
247246

248247
InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost(

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
127127
ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
128128
VectorType *SubTp, ArrayRef<const Value *> Args = {},
129129
const Instruction *CxtI = nullptr) const override;
130-
InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
131-
const Value *Ptr, bool VariableMask,
132-
Align Alignment,
133-
TTI::TargetCostKind CostKind,
134-
const Instruction *I) const override;
130+
InstructionCost
131+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
132+
TTI::TargetCostKind CostKind) const override;
135133
InstructionCost getInterleavedMemoryOpCost(
136134
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
137135
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,19 +1120,24 @@ InstructionCost RISCVTTIImpl::getInterleavedMemoryOpCost(
11201120
return MemCost + ShuffleCost;
11211121
}
11221122

1123-
InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
1124-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1125-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
1123+
InstructionCost
1124+
RISCVTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
1125+
TTI::TargetCostKind CostKind) const {
1126+
1127+
bool IsLoad = MICA.getID() == Intrinsic::masked_gather ||
1128+
MICA.getID() == Intrinsic::vp_gather;
1129+
unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store;
1130+
Type *DataTy = MICA.getDataType();
1131+
Align Alignment = MICA.getAlignment();
1132+
const Instruction *I = MICA.getInst();
11261133
if (CostKind != TTI::TCK_RecipThroughput)
1127-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1128-
Alignment, CostKind, I);
1134+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
11291135

11301136
if ((Opcode == Instruction::Load &&
11311137
!isLegalMaskedGather(DataTy, Align(Alignment))) ||
11321138
(Opcode == Instruction::Store &&
11331139
!isLegalMaskedScatter(DataTy, Align(Alignment))))
1134-
return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1135-
Alignment, CostKind, I);
1140+
return BaseT::getGatherScatterOpCost(MICA, CostKind);
11361141

11371142
// Cost is proportional to the number of memory operations implied. For
11381143
// scalable vectors, we use an estimate on that number since we don't

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
190190
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
191191
bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
192192

193-
InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
194-
const Value *Ptr, bool VariableMask,
195-
Align Alignment,
196-
TTI::TargetCostKind CostKind,
197-
const Instruction *I) const override;
193+
InstructionCost
194+
getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
195+
TTI::TargetCostKind CostKind) const override;
198196

199197
InstructionCost
200198
getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,

0 commit comments

Comments
 (0)