Skip to content

Commit 114ca65

Browse files
authored
[TTI] Use MemIntrinsicCostAttributes for getStridedOpCost (#170436)
- Following #168029. This is a step toward a unified interface for masked/gather-scatter/strided/expand-compress cost modeling. - Replace the ad-hoc parameter list with a single attributes object. API change: ``` - InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I = nullptr); + InstructionCost getStridedMemoryOpCost(MemIntrinsicCostAttributes, + CostKind); ``` Notes: - NFCI intended: callers populate MemIntrinsicCostAttributes with same information as before.
1 parent 9296223 commit 114ca65

File tree

4 files changed

+30
-38
lines changed

4 files changed

+30
-38
lines changed

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -864,10 +864,8 @@ class TargetTransformInfoImplBase {
864864
}
865865

866866
virtual InstructionCost
867-
getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
868-
bool VariableMask, Align Alignment,
869-
TTI::TargetCostKind CostKind,
870-
const Instruction *I = nullptr) const {
867+
getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
868+
TTI::TargetCostKind CostKind) const {
871869
return InstructionCost::getInvalid();
872870
}
873871

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,19 +1599,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
15991599
/*IsGatherScatter*/ true, CostKind);
16001600
}
16011601

1602-
InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
1603-
const Value *Ptr, bool VariableMask,
1604-
Align Alignment,
1605-
TTI::TargetCostKind CostKind,
1606-
const Instruction *I) const override {
1602+
InstructionCost
1603+
getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
1604+
TTI::TargetCostKind CostKind) const override {
16071605
// For a target without strided memory operations (or for an illegal
16081606
// operation type on one which does), assume we lower to a gather/scatter
16091607
// operation. (Which may in turn be scalarized.)
1610-
unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
1611-
: Intrinsic::masked_scatter;
1608+
unsigned IID = MICA.getID() == Intrinsic::experimental_vp_strided_load
1609+
? Intrinsic::masked_gather
1610+
: Intrinsic::masked_scatter;
16121611
return thisT()->getGatherScatterOpCost(
1613-
MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment,
1614-
I),
1612+
MemIntrinsicCostAttributes(IID, MICA.getDataType(), MICA.getPointer(),
1613+
MICA.getVariableMask(), MICA.getAlignment(),
1614+
MICA.getInst()),
16151615
CostKind);
16161616
}
16171617

@@ -3062,21 +3062,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
30623062
getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
30633063
TTI::TargetCostKind CostKind) const override {
30643064
unsigned Id = MICA.getID();
3065-
Type *DataTy = MICA.getDataType();
3066-
const Value *Ptr = MICA.getPointer();
3067-
const Instruction *I = MICA.getInst();
3068-
bool VariableMask = MICA.getVariableMask();
3069-
Align Alignment = MICA.getAlignment();
30703065

30713066
switch (Id) {
30723067
case Intrinsic::experimental_vp_strided_load:
3073-
case Intrinsic::experimental_vp_strided_store: {
3074-
unsigned Opcode = Id == Intrinsic::experimental_vp_strided_load
3075-
? Instruction::Load
3076-
: Instruction::Store;
3077-
return thisT()->getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
3078-
Alignment, CostKind, I);
3079-
}
3068+
case Intrinsic::experimental_vp_strided_store:
3069+
return thisT()->getStridedMemoryOpCost(MICA, CostKind);
30803070
case Intrinsic::masked_scatter:
30813071
case Intrinsic::masked_gather:
30823072
case Intrinsic::vp_scatter:

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,14 +1212,20 @@ InstructionCost RISCVTTIImpl::getExpandCompressMemoryOpCost(
12121212
LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind);
12131213
}
12141214

1215-
InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
1216-
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1217-
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
1218-
if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
1219-
!isLegalStridedLoadStore(DataTy, Alignment)) ||
1220-
(Opcode != Instruction::Load && Opcode != Instruction::Store))
1221-
return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
1222-
Alignment, CostKind, I);
1215+
InstructionCost
1216+
RISCVTTIImpl::getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
1217+
TTI::TargetCostKind CostKind) const {
1218+
1219+
unsigned Opcode = MICA.getID() == Intrinsic::experimental_vp_strided_load
1220+
? Instruction::Load
1221+
: Instruction::Store;
1222+
1223+
Type *DataTy = MICA.getDataType();
1224+
Align Alignment = MICA.getAlignment();
1225+
const Instruction *I = MICA.getInst();
1226+
1227+
if (!isLegalStridedLoadStore(DataTy, Alignment))
1228+
return BaseT::getStridedMemoryOpCost(MICA, CostKind);
12231229

12241230
if (CostKind == TTI::TCK_CodeSize)
12251231
return TTI::TCC_Basic;

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
202202
getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
203203
TTI::TargetCostKind CostKind) const override;
204204

205-
InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
206-
const Value *Ptr, bool VariableMask,
207-
Align Alignment,
208-
TTI::TargetCostKind CostKind,
209-
const Instruction *I) const override;
205+
InstructionCost
206+
getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
207+
TTI::TargetCostKind CostKind) const override;
210208

211209
InstructionCost
212210
getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const override;

0 commit comments

Comments
 (0)