From 879f46099916134611d7b740fdd494607380cdd4 Mon Sep 17 00:00:00 2001 From: ShihPo Hung Date: Fri, 14 Nov 2025 06:19:43 -0800 Subject: [PATCH] [TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost - Following from #168029. This is a step toward a unified interface for masked/gather-scatter/strided/expand-compress cost modeling. - Replace the ad-hoc parameter list with a single attributes object. API change: ``` - InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, Inst); + InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes, + CostKind); ``` Notes: - NFCI intended: callers populate MemIntrinsicCostAttributes with same information as before. --- .../llvm/Analysis/TargetTransformInfoImpl.h | 6 ++-- llvm/include/llvm/CodeGen/BasicTTIImpl.h | 31 ++++++++++--------- .../AArch64/AArch64TargetTransformInfo.cpp | 19 +++++++++--- .../AArch64/AArch64TargetTransformInfo.h | 6 ++-- .../lib/Target/ARM/ARMTargetTransformInfo.cpp | 16 +++++++--- llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 6 ++-- .../Hexagon/HexagonTargetTransformInfo.cpp | 9 +++--- .../Hexagon/HexagonTargetTransformInfo.h | 8 ++--- .../Target/RISCV/RISCVTargetTransformInfo.cpp | 19 +++++++----- .../Target/RISCV/RISCVTargetTransformInfo.h | 8 ++--- .../lib/Target/X86/X86TargetTransformInfo.cpp | 16 ++++++---- llvm/lib/Target/X86/X86TargetTransformInfo.h | 8 ++--- 12 files changed, 83 insertions(+), 69 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 624302bc6d0a3..9b2a7f432a544 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -852,10 +852,8 @@ class TargetTransformInfoImplBase { } virtual InstructionCost - getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const { + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { return 1; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index b1beb68feca46..dc0b729000881 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -1571,10 +1571,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { } InstructionCost - getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const override { + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override { + unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather || + MICA.getID() == Intrinsic::vp_gather) + ? Instruction::Load + : Instruction::Store; + Type *DataTy = MICA.getDataType(); + bool VariableMask = MICA.getVariableMask(); + Align Alignment = MICA.getAlignment(); return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask, true, CostKind); } @@ -1602,8 +1607,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { // For a target without strided memory operations (or for an illegal // operation type on one which does), assume we lower to a gather/scatter // operation. (Which may in turn be scalarized.) - return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather + : Intrinsic::masked_scatter; + return thisT()->getGatherScatterOpCost( + MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment, + I), + CostKind); } InstructionCost getInterleavedMemoryOpCost( @@ -3062,14 +3071,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { case Intrinsic::masked_scatter: case Intrinsic::masked_gather: case Intrinsic::vp_scatter: - case Intrinsic::vp_gather: { - unsigned Opcode = - (Id == Intrinsic::masked_gather || Id == Intrinsic::vp_gather) - ? Instruction::Load - : Instruction::Store; - return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); - } + case Intrinsic::vp_gather: + return thisT()->getGatherScatterOpCost(MICA, CostKind); case Intrinsic::masked_load: case Intrinsic::masked_store: return thisT()->getMaskedMemoryOpCost(MICA, CostKind); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 52f0b75430d9d..ad8c4f9974916 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -4782,12 +4782,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode, } } -InstructionCost AArch64TTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { +InstructionCost +AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + + unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather || + MICA.getID() == Intrinsic::vp_gather) + ? Instruction::Load + : Instruction::Store; + + Type *DataTy = MICA.getDataType(); + Align Alignment = MICA.getAlignment(); + const Instruction *I = MICA.getInst(); + if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy)) - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getGatherScatterOpCost(MICA, CostKind); auto *VT = cast(DataTy); auto LT = getTypeLegalizationCost(DataTy); if (!LT.first.isValid()) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 52fc28a98449b..a15886b2a6afa 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -192,10 +192,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase { TTI::TargetCostKind CostKind) const override; InstructionCost - getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const override; + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst, Type *Src) const; diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index fdb0ec40cb41f..bb19c8415858f 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -1694,13 +1694,19 @@ InstructionCost ARMTTIImpl::getInterleavedMemoryOpCost( UseMaskForCond, UseMaskForGaps); } -InstructionCost ARMTTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { +InstructionCost +ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + + Type *DataTy = MICA.getDataType(); + const Value *Ptr = MICA.getPointer(); + bool VariableMask = MICA.getVariableMask(); + Align Alignment = MICA.getAlignment(); + const Instruction *I = MICA.getInst(); + using namespace PatternMatch; if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters) - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getGatherScatterOpCost(MICA, CostKind); assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!"); auto *VTy = cast(DataTy); diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h index 30f2151b41239..e5f9dd5fe26d9 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -288,10 +288,8 @@ class ARMTTIImpl final : public BasicTTIImplBase { bool UseMaskForCond = false, bool UseMaskForGaps = false) const override; InstructionCost - getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const override; + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy, diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index 3f84cbb6555ed..c7d6ee306be84 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -238,11 +238,10 @@ HexagonTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, return 1; } -InstructionCost HexagonTTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); +InstructionCost +HexagonTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + return BaseT::getGatherScatterOpCost(MICA, CostKind); } InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost( diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h index 67388984bb3e3..3135958ba4b50 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -127,11 +127,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase { ArrayRef Mask, TTI::TargetCostKind CostKind, int Index, VectorType *SubTp, ArrayRef Args = {}, const Instruction *CxtI = nullptr) const override; - InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy, - const Value *Ptr, bool VariableMask, - Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I) const override; + InstructionCost + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; InstructionCost getInterleavedMemoryOpCost( unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 4788a428d7e64..64ed444b1a805 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1120,19 +1120,24 @@ InstructionCost RISCVTTIImpl::getInterleavedMemoryOpCost( return MemCost + ShuffleCost; } -InstructionCost RISCVTTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { +InstructionCost +RISCVTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + + bool IsLoad = MICA.getID() == Intrinsic::masked_gather || + MICA.getID() == Intrinsic::vp_gather; + unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store; + Type *DataTy = MICA.getDataType(); + Align Alignment = MICA.getAlignment(); + const Instruction *I = MICA.getInst(); if (CostKind != TTI::TCK_RecipThroughput) - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getGatherScatterOpCost(MICA, CostKind); if ((Opcode == Instruction::Load && !isLegalMaskedGather(DataTy, Align(Alignment))) || (Opcode == Instruction::Store && !isLegalMaskedScatter(DataTy, Align(Alignment)))) - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getGatherScatterOpCost(MICA, CostKind); // Cost is proportional to the number of memory operations implied. For // scalable vectors, we use an estimate on that number since we don't diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 5efa330b3ad71..b5822a139cc36 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -190,11 +190,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase { Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, bool UseMaskForCond = false, bool UseMaskForGaps = false) const override; - InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy, - const Value *Ptr, bool VariableMask, - Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I) const override; + InstructionCost + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; InstructionCost getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA, diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index a9dc96b53d530..aa75d2c60803e 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -6258,10 +6258,15 @@ InstructionCost X86TTIImpl::getGSVectorCost(unsigned Opcode, } /// Calculate the cost of Gather / Scatter operation -InstructionCost X86TTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *SrcVTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const { +InstructionCost +X86TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + bool IsLoad = MICA.getID() == Intrinsic::masked_gather || + MICA.getID() == Intrinsic::vp_gather; + unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store; + Type *SrcVTy = MICA.getDataType(); + const Value *Ptr = MICA.getPointer(); + Align Alignment = MICA.getAlignment(); if ((Opcode == Instruction::Load && (!isLegalMaskedGather(SrcVTy, Align(Alignment)) || forceScalarizeMaskedGather(cast(SrcVTy), @@ -6270,8 +6275,7 @@ InstructionCost X86TTIImpl::getGatherScatterOpCost( (!isLegalMaskedScatter(SrcVTy, Align(Alignment)) || forceScalarizeMaskedScatter(cast(SrcVTy), Align(Alignment))))) - return BaseT::getGatherScatterOpCost(Opcode, SrcVTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getGatherScatterOpCost(MICA, CostKind); assert(SrcVTy->isVectorTy() && "Unexpected data type for Gather/Scatter"); PointerType *PtrTy = dyn_cast(Ptr->getType()); diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index d6dea9427990b..d35911965d8b5 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -185,11 +185,9 @@ class X86TTIImpl final : public BasicTTIImplBase { InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const override; - InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy, - const Value *Ptr, bool VariableMask, - Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I) const override; + InstructionCost + getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; InstructionCost getPointersChainCost(ArrayRef Ptrs, const Value *Base, const TTI::PointersChainInfo &Info, Type *AccessTy,