
Conversation

peterbell10 (Contributor) commented Feb 20, 2025

We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost of building a vector. PTX has a single mov instruction that can build e.g. <2 x half> vectors from scalars; the SLPVectorizer, however, estimates this as the cost of 2 insertelement instructions.

To fix this, I customize getScalarizationOverhead to lower the cost of building 2x16-bit vector types.
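
For reference, a rough sketch of the shape such an override might take in NVPTXTTIImpl (illustrative only, not the final patch: the Isv2x16VT helper is the one used in the diff below, while the BaseT alias and the exact base-class signature are assumptions here):

InstructionCost getScalarizationOverhead(VectorType *InTy,
                                         const APInt &DemandedElts,
                                         bool Insert, bool Extract,
                                         TTI::TargetCostKind CostKind,
                                         ArrayRef<Value *> VL = {}) {
  InstructionCost Cost = 0;
  // Building a 2x16-bit vector (e.g. <2 x half>) from scalars is a single
  // mov.b32 %r, {%h1, %h2} in PTX, so charge 1 rather than two inserts.
  if (Insert && Isv2x16VT(getTLI()->getValueType(DL, InTy))) {
    Cost += 1;
    Insert = false; // The insert side is now fully accounted for.
  }
  // Delegate everything else, including the extract side, to the generic
  // estimate.
  return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert,
                                                Extract, CostKind, VL);
}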

github-actions bot commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
return TTI::TCC_Free;
auto *VecTy = getWidenedType(ScalarTy, VL.size());
if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
peterbell10 (Contributor, Author) commented:

I'd be happy to hear alternatives that make better use of the existing getVectorInstrCost interface. I had tried pattern matching InsertElt(InsertElt(poison, a, 0), b, 1), but found that the second insert is also called with poison, so we can't pattern match it. E.g., here both calls are passed poison:

InstructionCost InsertFirstCost = TTI->getVectorInstrCost(
Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, 0,
PoisonValue::get(Ty), *It);
InstructionCost InsertIdxCost = TTI->getVectorInstrCost(
Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, Idx,
PoisonValue::get(Ty), *It);

A reviewer (Member) commented:

It should not be here. Rework getScalarizationOverhead instead

peterbell10 (Contributor, Author) replied:

Thanks, that's exactly what I was looking for!

peterbell10 marked this pull request as ready for review February 21, 2025 16:22
llvmbot added the vectorizers, backend:NVPTX, llvm:analysis, and llvm:transforms labels Feb 21, 2025
llvmbot (Member) commented Feb 21, 2025

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-vectorizers

Author: None (peterbell10)

Changes

We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost of building a vector. PTX has a single mov instruction that can build e.g. <2 x half> vectors from scalars; the SLPVectorizer, however, estimates this as the cost of 2 insertelement instructions.

To fix this I add TargetTransformInfo::getBuildVectorCost so the target can optionally specify the exact cost.


Full diff: https://github.com/llvm/llvm-project/pull/128077.diff

7 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+16)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+22-1)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+3)
  • (modified) llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll (+110-46)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9048481b49189..c3a17a6fdb29d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1479,6 +1479,12 @@ class TargetTransformInfo {
   InstructionCost getInsertExtractValueCost(unsigned Opcode,
                                             TTI::TargetCostKind CostKind) const;
 
+  /// \return The cost of ISD::BUILD_VECTOR, or nullopt if the cost should be
+  /// inferred from insert element and shuffle ops.
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TargetCostKind CostKind) const;
+
   /// \return The cost of replication shuffle of \p VF elements typed \p EltTy
   /// \p ReplicationFactor times.
   ///
@@ -2224,6 +2230,10 @@ class TargetTransformInfo::Concept {
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index) = 0;
 
+  virtual std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TargetCostKind CostKind) = 0;
+
   virtual InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
@@ -2952,6 +2962,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      unsigned Index) override {
     return Impl.getVectorInstrCost(I, Val, CostKind, Index);
   }
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) override {
+    return Impl.getBuildVectorCost(VecTy, Operands, CostKind);
+  }
+
   InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a8d6dd18266bb..f7ef03bea2548 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -739,6 +739,12 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *Val, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) const {
+    return std::nullopt;
+  }
+
   unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                                      const APInt &DemandedDstElts,
                                      TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 032c7d7b5159e..a58c8dcee49c2 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1373,6 +1373,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                        Op1);
   }
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) {
+    return std::nullopt;
+  }
+
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
                                             const APInt &DemandedDstElts,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1ca9a16b18112..69b8f6706b563 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1123,6 +1123,13 @@ InstructionCost TargetTransformInfo::getInsertExtractValueCost(
   return Cost;
 }
 
+std::optional<InstructionCost>
+TargetTransformInfo::getBuildVectorCost(VectorType *VecTy,
+                                        ArrayRef<Value *> Operands,
+                                        TargetCostKind CostKind) const {
+  return TTIImpl->getBuildVectorCost(VecTy, Operands, CostKind);
+}
+
 InstructionCost TargetTransformInfo::getReplicationShuffleCost(
     Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedDstElts,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b0a846a9c7f96..cbe94a2e82279 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -16,8 +16,9 @@
 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
 #define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
 
-#include "NVPTXTargetMachine.h"
 #include "MCTargetDesc/NVPTXBaseInfo.h"
+#include "NVPTXTargetMachine.h"
+#include "NVPTXUtilities.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -100,6 +101,26 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
 
+  std::optional<InstructionCost>
+  getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+                     TTI::TargetCostKind CostKind) {
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return std::nullopt;
+    auto VT = getTLI()->getValueType(DL, VecTy);
+    if (all_of(Operands, [](Value *Op) { return isa<Constant>(Op); }))
+      return TTI::TCC_Free;
+    if (Isv2x16VT(VT))
+      return 1; // Single vector mov
+    if (VT == MVT::v4i8) {
+      InstructionCost Cost = 3; // 3 x PRMT
+      for (auto *Op : Operands)
+        if (!isa<Constant>(Op))
+          Cost += 1; // zext operand to i32
+      return Cost;
+    }
+    return std::nullopt;
+  }
+
   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                TTI::UnrollingPreferences &UP,
                                OptimizationRemarkEmitter *ORE);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f2aa0e8328585..cfe7bbc641906 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10203,6 +10203,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
       return TTI::TCC_Free;
     auto *VecTy = getWidenedType(ScalarTy, VL.size());
+    if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
+        Cost.has_value())
+      return *Cost;
     InstructionCost GatherCost = 0;
     SmallVector<Value *> Gathers(VL);
     if (!Root && isSplat(VL)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
index 13773bf901b9b..c74909d7ceb2a 100644
--- a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
+++ b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
@@ -1,59 +1,123 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefixes=VECTOR,SM90
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefixes=VECTOR,SM80
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefixes=VECTOR,SM70
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefixes=NOVECTOR,SM50
 
 define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
-; CHECK-LABEL: @fusion(
-; CHECK-NEXT:    [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; CHECK-NEXT:    [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
-; CHECK-NEXT:    store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
-; CHECK-NEXT:    ret void
+; VECTOR-LABEL: @fusion(
+; VECTOR-NEXT:    [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; VECTOR-NEXT:    [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; VECTOR-NEXT:    [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; VECTOR-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; VECTOR-NEXT:    [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT:    [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT:    [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
+; VECTOR-NEXT:    [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
+; VECTOR-NEXT:    [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
+; VECTOR-NEXT:    store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
+; VECTOR-NEXT:    ret void
 ;
 ; NOVECTOR-LABEL: @fusion(
-; NOVECTOR-NEXT:    [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; NOVECTOR-NEXT:    [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; NOVECTOR-NEXT:    [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; NOVECTOR-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; NOVECTOR-NEXT:    [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
-; NOVECTOR-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT:    [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
+; NOVECTOR-NEXT:    [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; NOVECTOR-NEXT:    [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; NOVECTOR-NEXT:    [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; NOVECTOR-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; NOVECTOR-NEXT:    [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
+; NOVECTOR-NEXT:    [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT:    [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
+; NOVECTOR-NEXT:    [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
+; NOVECTOR-NEXT:    [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
+; NOVECTOR-NEXT:    [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT:    store half [[TMP9]], ptr [[TMP6]], align 8
+; NOVECTOR-NEXT:    [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
+; NOVECTOR-NEXT:    [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
 ; NOVECTOR-NEXT:    [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
 ; NOVECTOR-NEXT:    [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
-; NOVECTOR-NEXT:    [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT:    store half [[TMP14]], ptr [[TMP16]], align 8
-; NOVECTOR-NEXT:    [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
-; NOVECTOR-NEXT:    [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
-; NOVECTOR-NEXT:    [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
-; NOVECTOR-NEXT:    [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
-; NOVECTOR-NEXT:    [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
-; NOVECTOR-NEXT:    store half [[TMP20]], ptr [[TMP21]], align 2
+; NOVECTOR-NEXT:    [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
+; NOVECTOR-NEXT:    store half [[TMP14]], ptr [[TMP15]], align 2
 ; NOVECTOR-NEXT:    ret void
 ;
-  %tmp = shl nuw nsw i32 %arg2, 6
-  %tmp4 = or i32 %tmp, %arg3
-  %tmp5 = shl nuw nsw i32 %tmp4, 2
-  %tmp6 = zext i32 %tmp5 to i64
-  %tmp7 = or disjoint i64 %tmp6, 1
-  %tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
-  %tmp12 = load half, ptr %tmp11, align 8
-  %tmp13 = fmul fast half %tmp12, 0xH5380
-  %tmp14 = fadd fast half %tmp13, 0xH57F0
-  %tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
-  store half %tmp14, ptr %tmp16, align 8
-  %tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
-  %tmp18 = load half, ptr %tmp17, align 2
-  %tmp19 = fmul fast half %tmp18, 0xH5380
-  %tmp20 = fadd fast half %tmp19, 0xH57F0
-  %tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
-  store half %tmp20, ptr %tmp21, align 2
+  %1 = shl nuw nsw i32 %arg2, 6
+  %4 = or i32 %1, %arg3
+  %5 = shl nuw nsw i32 %4, 2
+  %6 = zext i32 %5 to i64
+  %7 = or disjoint i64 %6, 1
+  %11 = getelementptr inbounds half, ptr %arg1, i64 %6
+  %12 = load half, ptr %11, align 8
+  %13 = fmul fast half %12, 0xH5380
+  %14 = fadd fast half %13, 0xH57F0
+  %16 = getelementptr inbounds half, ptr %arg, i64 %6
+  store half %14, ptr %16, align 8
+  %17 = getelementptr inbounds half, ptr %arg1, i64 %7
+  %18 = load half, ptr %17, align 2
+  %19 = fmul fast half %18, 0xH5380
+  %20 = fadd fast half %19, 0xH57F0
+  %21 = getelementptr inbounds half, ptr %arg, i64 %7
+  store half %20, ptr %21, align 2
   ret void
 }
 
+define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
+; VECTOR-LABEL: @add_f16(
+; VECTOR-NEXT:    [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; VECTOR-NEXT:    [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; VECTOR-NEXT:    [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; VECTOR-NEXT:    [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; VECTOR-NEXT:    [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
+; VECTOR-NEXT:    [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
+; VECTOR-NEXT:    [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
+; VECTOR-NEXT:    [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
+; VECTOR-NEXT:    [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
+; VECTOR-NEXT:    [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; VECTOR-NEXT:    [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; VECTOR-NEXT:    [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; VECTOR-NEXT:    [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; VECTOR-NEXT:    [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; VECTOR-NEXT:    store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; VECTOR-NEXT:    ret void
+;
+; NOVECTOR-LABEL: @add_f16(
+; NOVECTOR-NEXT:    [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; NOVECTOR-NEXT:    [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; NOVECTOR-NEXT:    [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; NOVECTOR-NEXT:    [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; NOVECTOR-NEXT:    [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
+; NOVECTOR-NEXT:    [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
+; NOVECTOR-NEXT:    [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; NOVECTOR-NEXT:    [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; NOVECTOR-NEXT:    [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; NOVECTOR-NEXT:    [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; NOVECTOR-NEXT:    [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; NOVECTOR-NEXT:    [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
+; NOVECTOR-NEXT:    [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
+; NOVECTOR-NEXT:    store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; NOVECTOR-NEXT:    ret void
+;
+  %5 = extractvalue { half, half } %1, 0
+  %6 = extractvalue { half, half } %1, 1
+  %7 = extractvalue { half, half } %2, 0
+  %8 = extractvalue { half, half } %2, 1
+  %9 = fadd half %5, %7
+  %10 = fadd half %6, %8
+  %11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %12 = shl i32 %11, 1
+  %13 = and i32 %12, 62
+  %14 = zext nneg i32 %13 to i64
+  %15 = getelementptr half, ptr addrspace(1) %0, i64 %14
+  %18 = insertelement <2 x half> poison, half %9, i64 0
+  %19 = insertelement <2 x half> %18, half %10, i64 1
+  store <2 x half> %19, ptr addrspace(1) %15, align 4
+  ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
 attributes #0 = { nounwind }
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; SM50: {{.*}}
+; SM70: {{.*}}
+; SM80: {{.*}}
+; SM90: {{.*}}

llvmbot (Member) commented Feb 21, 2025

@llvm/pr-subscribers-backend-nvptx

Author: None (peterbell10)

Changes: same as above. Full diff: https://github.com/llvm/llvm-project/pull/128077.diff

peterbell10 requested a review from Artem-B February 21, 2025 16:23
Artem-B (Member) commented Feb 21, 2025

PTX has a single mov instruction that can build e.g. <2 x half> vectors from scalars, however the SLPVectorizer over-estimates it as the cost of 2 insert elements.

A single instruction in PTX does not mean that it's efficiently implemented in hardware.

In this particular case, mov.b32 %r1, {%h1, %h2} takes three (different) instructions on GPU:
https://godbolt.org/z/1jWfsjfrz

I think that the current estimate, that construction of a v2f16 vector costs us the equivalent of a few logical ops, is quite reasonable.

peterbell10 (Contributor, Author) commented:

I think your example is a bit misleading because it includes the argument-passing convention; if we read the values from gmem instead, the mov becomes a single SASS instruction:
https://godbolt.org/z/rKeGTrj96

I am a bit surprised though that ptxas keeps those extra PRMT calls in for the argument passing case. AFAICT it's zeroing out the top 16 bits, then immediately discarding them.

peterbell10 changed the title from "[SLPVectorizer][NVPTX] Customize getBuildVectorCost for NVPTX" to "[NVPTX] Customize getScalarizationOverhead" Feb 21, 2025
peterbell10 removed the llvm:analysis label Feb 21, 2025
Comment on lines +131 to +137. A Contributor commented:

I'm not very familiar with this API. Could you explain why we need to incur the cost of the zext? Doesn't prmt emit an i32?


peterbell10 (Contributor, Author) replied:

It occurs to me that we don't actually need the zeroing part; we really just need to extend the register type, but unfortunately I can't think of a way to express that in the PTX. Certainly the current lowering sometimes produces SASS that zeroes the top part of the register, e.g. see the LOP3 instruction in this:
https://godbolt.org/z/EPjjzv4cz

It seems that, for the load, ptxas knows the top part is zeroed, but if it's unsure it will mask the lower bits.
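
To make the v4i8 costing concrete (reading directly off the hook in the diff above): building a <4 x i8> from four non-constant scalars is estimated as 3 (the PRMTs) + 4 (one i32 extend per non-constant operand) = 7, while a build where two of the lanes are constants would come to 3 + 2 = 5.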

Artem-B (Member) commented Feb 21, 2025

I think your example is a bit misleading because it includes the argument passing convention, if we read the values from gmem instead the mov becomes a single SASS instruction: https://godbolt.org/z/rKeGTrj96

Fair enough. It's still not free, as it needs prmt/xmad. I think that the cost adjustment here is a wash. In the end it will end up with about the same instructions at the SASS level, whether we let LLVM construct it by element insertion or by logical ops; there's just no way to bypass the fact that all registers at the SASS level are 32-bit, so we end up shuffling bits one way or another.

Perhaps I'm missing something? Can you compile the test cases you've added to v2f16.ll to PTX, and see what we get for both variants in both PTX and SASS?

peterbell10 (Contributor, Author) commented Feb 21, 2025

Fair enough. It's still not free as it needs prmt/xmad. I think that the cost adjustment here is a wash.

The current heuristic results in a cost of 2 to build a <2 x half>, and on SM90 the test case I added is not currently vectorized, so this is a meaningful change.
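
For a rough sense of the arithmetic (ignoring other terms in the model): the default estimate charges 2 insertelements per <2 x half> operand build, while this change charges 1 (a single mov). In add_f16 two such operands are built, so the estimated vector-tree cost drops by 2, which is enough to flip the profitability decision in this test.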

Perhaps I'm missing something? Can you compile the test cases you've added to v2f16.ll to PTX, and see what we get for both variants in both PTX and SASS?

Here are the results for sm_90:

PTX for sm_90
//
// Generated by LLVM NVPTX Back-End
//

.version 7.8
.target sm_90
.address_size 64

	// .globl	fusion                  // -- Begin function fusion
                                        // @fusion
.visible .func fusion(
	.param .b64 fusion_param_0,
	.param .b64 fusion_param_1,
	.param .b32 fusion_param_2,
	.param .b32 fusion_param_3
)
{
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<6>;

// %bb.0:
	ld.param.u64 	%rd1, [fusion_param_0];
	ld.param.u64 	%rd2, [fusion_param_1];
	ld.param.u32 	%r1, [fusion_param_2];
	ld.param.u32 	%r2, [fusion_param_3];
	shl.b32 	%r3, %r2, 2;
	shl.b32 	%r4, %r1, 8;
	or.b32  	%r5, %r4, %r3;
	mul.wide.u32 	%rd3, %r5, 2;
	add.s64 	%rd4, %rd2, %rd3;
	add.s64 	%rd5, %rd1, %rd3;
	ld.b32 	%r6, [%rd4];
	mov.b32 	%r7, 1475368944;
	mov.b32 	%r8, 1400918912;
	fma.rn.f16x2 	%r9, %r6, %r8, %r7;
	st.b32 	[%rd5], %r9;
	ret;
                                        // -- End function
}
	// .globl	add_f16                 // -- Begin function add_f16
.visible .entry add_f16(
	.param .u64 .ptr .global .align 1 add_f16_param_0,
	.param .align 2 .b8 add_f16_param_1[4],
	.param .align 2 .b8 add_f16_param_2[4]
)                                       // @add_f16
{
	.reg .b16 	%rs<5>;
	.reg .b32 	%r<7>;
	.reg .b64 	%rd<4>;

// %bb.0:
	ld.param.u64 	%rd1, [add_f16_param_0];
	ld.param.b16 	%rs1, [add_f16_param_1+2];
	ld.param.b16 	%rs2, [add_f16_param_1];
	ld.param.b16 	%rs3, [add_f16_param_2+2];
	ld.param.b16 	%rs4, [add_f16_param_2];
	mov.b32 	%r1, {%rs2, %rs1};
	mov.b32 	%r2, {%rs4, %rs3};
	add.rn.f16x2 	%r3, %r1, %r2;
	mov.u32 	%r4, %tid.x;
	shl.b32 	%r5, %r4, 1;
	and.b32  	%r6, %r5, 62;
	mul.wide.u32 	%rd2, %r6, 2;
	add.s64 	%rd3, %rd1, %rd2;
	st.global.b32 	[%rd3], %r3;
	ret;
                                        // -- End function
}
SASS for sm_90
	code for sm_90
		Function : add_f16
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   LDC R1, c[0x0][0x28] ;                  /* 0x00000a00ff017b82 */
                                                                           /* 0x000fe20000000800 */
        /*0010*/                   S2R R0, SR_TID.X ;                      /* 0x0000000000007919 */
                                                                           /* 0x000e6e0000002100 */
        /*0020*/                   LDC.64 R4, c[0x0][0x218] ;              /* 0x00008600ff047b82 */
                                                                           /* 0x000ea20000000a00 */
        /*0030*/                   ULDC.64 UR4, c[0x0][0x208] ;            /* 0x0000820000047ab9 */
                                                                           /* 0x000fce0000000a00 */
        /*0040*/                   LDC.64 R2, c[0x0][0x210] ;              /* 0x00008400ff027b82 */
                                                                           /* 0x000ee20000000a00 */
        /*0050*/                   HFMA2.MMA R5, R4, 1, 1, R5 ;            /* 0x3c003c0004057835 */
                                                                           /* 0x004fe20000000005 */
        /*0060*/                   SHF.L.U32 R0, R0, 0x1, RZ ;             /* 0x0000000100007819 */
                                                                           /* 0x002fc800000006ff */
        /*0070*/                   LOP3.LUT R7, R0, 0x3e, RZ, 0xc0, !PT ;  /* 0x0000003e00077812 */
                                                                           /* 0x000fca00078ec0ff */
        /*0080*/                   IMAD.WIDE.U32 R2, R7, 0x2, R2 ;         /* 0x0000000207027825 */
                                                                           /* 0x008fca00078e0002 */
        /*0090*/                   STG.E desc[UR4][R2.64], R5 ;            /* 0x0000000502007986 */
                                                                           /* 0x000fe2000c101904 */
        /*00a0*/                   EXIT ;                                  /* 0x000000000000794d */
                                                                           /* 0x000fea0003800000 */
        /*00b0*/                   BRA 0xb0;                               /* 0xfffffffc00fc7947 */
                                                                           /* 0x000fc0000383ffff */
        /*00c0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00d0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00e0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00f0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0100*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0110*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0120*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0130*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0140*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0150*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0160*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0170*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
		..........


		Function : fusion
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   SHF.L.U32 R9, R9, 0x2, RZ ;             /* 0x0000000209097819 */
                                                                           /* 0x000fe200000006ff */
        /*0010*/                   ULDC.64 UR4, c[0x0][0x208] ;            /* 0x0000820000047ab9 */
                                                                           /* 0x000fe20000000a00 */
        /*0020*/                   SHF.L.U32 R8, R8, 0x8, RZ ;             /* 0x0000000808087819 */
                                                                           /* 0x000fc800000006ff */
        /*0030*/                   LOP3.LUT R9, R8, R9, RZ, 0xfc, !PT ;    /* 0x0000000908097212 */
                                                                           /* 0x000fca00078efcff */
        /*0040*/                   IMAD.WIDE.U32 R6, R9, 0x2, R6 ;         /* 0x0000000209067825 */
                                                                           /* 0x000fcc00078e0006 */
        /*0050*/                   LD.E R6, desc[UR4][R6.64] ;             /* 0x0000000406067980 */
                                                                           /* 0x000ea2000c101900 */
        /*0060*/                   HFMA2.MMA R3, -RZ, RZ, 0, 60 ;          /* 0x00005380ff037435 */
                                                                           /* 0x000fe200000001ff */
        /*0070*/                   IMAD.WIDE.U32 R4, R9, 0x2, R4 ;         /* 0x0000000209047825 */
                                                                           /* 0x000fd200078e0004 */
        /*0080*/                   HFMA2 R3, R6, R3.H0_H0, 127, 127 ;      /* 0x57f057f006037431 */
                                                                           /* 0x004fca0000040003 */
        /*0090*/                   ST.E desc[UR4][R4.64], R3 ;             /* 0x0000000304007985 */
                                                                           /* 0x0001e4000c101904 */
        /*00a0*/                   RET.ABS.NODEC R20 0x0 ;                 /* 0x0000000014007950 */
                                                                           /* 0x001fea0003e00000 */
        /*00b0*/                   BRA 0xb0;                               /* 0xfffffffc00fc7947 */
                                                                           /* 0x000fc0000383ffff */
        /*00c0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00d0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00e0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*00f0*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0100*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0110*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0120*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0130*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0140*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0150*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0160*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */
        /*0170*/                   NOP;                                    /* 0x0000000000007918 */
                                                                           /* 0x000fc00000000000 */

The SASS is a bit confusing, but I think LDC.64 R4, c[0x0][0x218] is loading R4 and R5 directly as f16x2 vector registers in a single instruction, which is quite cool.

Artem-B (Member) commented Feb 21, 2025

Here's a more apples-to-apples comparison, where I literally replaced mov.b32 with cvt.b32.b16, shift, or: https://godbolt.org/z/96s7nGGrx

AFAICT, the generated code is almost identical, modulo a different order of instructions.

peterbell10 (Contributor, Author) commented:

@Artem-B I don't follow your point. Yes, ptxas can optimize the insert-element sequence itself, but that doesn't have any effect on the slp-vectorizer's cost heuristics.

peterbell10 (Contributor, Author) commented:

Or do you mean that we should just rely on ptxas to do all the vectorization for us?

Artem-B (Member) commented Feb 21, 2025

Sorry, I got tunnel-visioned into looking only at the conversion to/from f16x2 vectors. That part will probably remain the same regardless of how we do it at the PTX level.

However, if the change does allow us to vectorize more scalar 16-bit ops, that is useful.

We've observed that the SLPVectorizer is too conservative on NVPTX because it
over-estimates the cost to build a vector. PTX has a single `mov` instruction
that can build <2 x half> vectors from scalars, however the SLPVectorizer
estimates the cost as 2 insert elements.

To fix this I add `TargetTransformInfo::getBuildVectorCost` so the
target can optionally specify the exact cost.
peterbell10 force-pushed the slp-buildvector-cost branch from e75affb to 52f75fc March 28, 2025 22:07
peterbell10 merged commit 55430f8 into llvm:main Mar 29, 2025
11 checks passed