Skip to content

Commit ec973a7

Browse files
committed
[SLPVectorizer][NVPTX] Customize getBuildVectorCost for NVPTX
We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost of building a vector. PTX has a single `mov` instruction that can build a <2 x half> vector from scalars, yet the SLPVectorizer estimates the cost as two insert-element operations. To fix this, add `TargetTransformInfo::getBuildVectorCost` so that a target can optionally specify the exact cost.
1 parent d1dde17 commit ec973a7

File tree

7 files changed

+170
-47
lines changed

7 files changed

+170
-47
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,12 @@ class TargetTransformInfo {
14791479
InstructionCost getInsertExtractValueCost(unsigned Opcode,
14801480
TTI::TargetCostKind CostKind) const;
14811481

1482+
/// \return The cost of ISD::BUILD_VECTOR, or nullopt if the cost should be
1483+
/// inferred from insert element and shuffle ops.
1484+
std::optional<InstructionCost>
1485+
getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
1486+
TargetCostKind CostKind) const;
1487+
14821488
/// \return The cost of replication shuffle of \p VF elements typed \p EltTy
14831489
/// \p ReplicationFactor times.
14841490
///
@@ -2224,6 +2230,10 @@ class TargetTransformInfo::Concept {
22242230
TTI::TargetCostKind CostKind,
22252231
unsigned Index) = 0;
22262232

2233+
virtual std::optional<InstructionCost>
2234+
getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
2235+
TargetCostKind CostKind) = 0;
2236+
22272237
virtual InstructionCost
22282238
getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
22292239
const APInt &DemandedDstElts,
@@ -2952,6 +2962,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
29522962
unsigned Index) override {
29532963
return Impl.getVectorInstrCost(I, Val, CostKind, Index);
29542964
}
2965+
std::optional<InstructionCost>
2966+
getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
2967+
TTI::TargetCostKind CostKind) override {
2968+
return Impl.getBuildVectorCost(VecTy, Operands, CostKind);
2969+
}
2970+
29552971
InstructionCost
29562972
getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
29572973
const APInt &DemandedDstElts,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,12 @@ class TargetTransformInfoImplBase {
739739
return 1;
740740
}
741741

742+
std::optional<InstructionCost>
743+
getBuildVectorCost(VectorType *Val, ArrayRef<Value *> Operands,
744+
TTI::TargetCostKind CostKind) const {
745+
return std::nullopt;
746+
}
747+
742748
unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
743749
const APInt &DemandedDstElts,
744750
TTI::TargetCostKind CostKind) {

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
13731373
Op1);
13741374
}
13751375

1376+
std::optional<InstructionCost>
1377+
getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
1378+
TTI::TargetCostKind CostKind) {
1379+
return std::nullopt;
1380+
}
1381+
13761382
InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
13771383
int VF,
13781384
const APInt &DemandedDstElts,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,13 @@ InstructionCost TargetTransformInfo::getInsertExtractValueCost(
11231123
return Cost;
11241124
}
11251125

1126+
std::optional<InstructionCost>
1127+
TargetTransformInfo::getBuildVectorCost(VectorType *VecTy,
1128+
ArrayRef<Value *> Operands,
1129+
TargetCostKind CostKind) const {
1130+
return TTIImpl->getBuildVectorCost(VecTy, Operands, CostKind);
1131+
}
1132+
11261133
InstructionCost TargetTransformInfo::getReplicationShuffleCost(
11271134
Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedDstElts,
11281135
TTI::TargetCostKind CostKind) const {

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
1717
#define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
1818

19-
#include "NVPTXTargetMachine.h"
2019
#include "MCTargetDesc/NVPTXBaseInfo.h"
20+
#include "NVPTXTargetMachine.h"
21+
#include "NVPTXUtilities.h"
2122
#include "llvm/Analysis/TargetTransformInfo.h"
2223
#include "llvm/CodeGen/BasicTTIImpl.h"
2324
#include "llvm/CodeGen/TargetLowering.h"
@@ -100,6 +101,26 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
100101
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
101102
ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
102103

104+
std::optional<InstructionCost>
105+
getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
106+
TTI::TargetCostKind CostKind) {
107+
if (CostKind != TTI::TCK_RecipThroughput)
108+
return std::nullopt;
109+
auto VT = getTLI()->getValueType(DL, VecTy);
110+
if (all_of(Operands, [](Value *Op) { return isa<Constant>(Op); }))
111+
return TTI::TCC_Free;
112+
if (Isv2x16VT(VT))
113+
return 1; // Single vector mov
114+
if (VT == MVT::v4i8) {
115+
InstructionCost Cost = 3; // 3 x PRMT
116+
for (auto *Op : Operands)
117+
if (!isa<Constant>(Op))
118+
Cost += 1; // zext operand to i32
119+
return Cost;
120+
}
121+
return std::nullopt;
122+
}
123+
103124
void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
104125
TTI::UnrollingPreferences &UP,
105126
OptimizationRemarkEmitter *ORE);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10203,6 +10203,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1020310203
if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
1020410204
return TTI::TCC_Free;
1020510205
auto *VecTy = getWidenedType(ScalarTy, VL.size());
10206+
if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
10207+
Cost.has_value())
10208+
return *Cost;
1020610209
InstructionCost GatherCost = 0;
1020710210
SmallVector<Value *> Gathers(VL);
1020810211
if (!Root && isSplat(VL)) {
Lines changed: 110 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,123 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2-
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
3-
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
2+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefixes=VECTOR,SM90
3+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefixes=VECTOR,SM80
4+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefixes=VECTOR,SM70
5+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefixes=NOVECTOR,SM50
46

57
define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
6-
; CHECK-LABEL: @fusion(
7-
; CHECK-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
8-
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
9-
; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
10-
; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
11-
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
12-
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
13-
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
14-
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
15-
; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
16-
; CHECK-NEXT: store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
17-
; CHECK-NEXT: ret void
8+
; VECTOR-LABEL: @fusion(
9+
; VECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
10+
; VECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
11+
; VECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
12+
; VECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
13+
; VECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
14+
; VECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
15+
; VECTOR-NEXT: [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
16+
; VECTOR-NEXT: [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
17+
; VECTOR-NEXT: [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
18+
; VECTOR-NEXT: store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
19+
; VECTOR-NEXT: ret void
1820
;
1921
; NOVECTOR-LABEL: @fusion(
20-
; NOVECTOR-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
21-
; NOVECTOR-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
22-
; NOVECTOR-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
23-
; NOVECTOR-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
24-
; NOVECTOR-NEXT: [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
25-
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
26-
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
22+
; NOVECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
23+
; NOVECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
24+
; NOVECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
25+
; NOVECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
26+
; NOVECTOR-NEXT: [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
27+
; NOVECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
28+
; NOVECTOR-NEXT: [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
29+
; NOVECTOR-NEXT: [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
30+
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
31+
; NOVECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
32+
; NOVECTOR-NEXT: store half [[TMP9]], ptr [[TMP6]], align 8
33+
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
34+
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
2735
; NOVECTOR-NEXT: [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
2836
; NOVECTOR-NEXT: [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
29-
; NOVECTOR-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
30-
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP16]], align 8
31-
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
32-
; NOVECTOR-NEXT: [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
33-
; NOVECTOR-NEXT: [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
34-
; NOVECTOR-NEXT: [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
35-
; NOVECTOR-NEXT: [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
36-
; NOVECTOR-NEXT: store half [[TMP20]], ptr [[TMP21]], align 2
37+
; NOVECTOR-NEXT: [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
38+
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP15]], align 2
3739
; NOVECTOR-NEXT: ret void
3840
;
39-
%tmp = shl nuw nsw i32 %arg2, 6
40-
%tmp4 = or i32 %tmp, %arg3
41-
%tmp5 = shl nuw nsw i32 %tmp4, 2
42-
%tmp6 = zext i32 %tmp5 to i64
43-
%tmp7 = or disjoint i64 %tmp6, 1
44-
%tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
45-
%tmp12 = load half, ptr %tmp11, align 8
46-
%tmp13 = fmul fast half %tmp12, 0xH5380
47-
%tmp14 = fadd fast half %tmp13, 0xH57F0
48-
%tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
49-
store half %tmp14, ptr %tmp16, align 8
50-
%tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
51-
%tmp18 = load half, ptr %tmp17, align 2
52-
%tmp19 = fmul fast half %tmp18, 0xH5380
53-
%tmp20 = fadd fast half %tmp19, 0xH57F0
54-
%tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
55-
store half %tmp20, ptr %tmp21, align 2
41+
%1 = shl nuw nsw i32 %arg2, 6
42+
%4 = or i32 %1, %arg3
43+
%5 = shl nuw nsw i32 %4, 2
44+
%6 = zext i32 %5 to i64
45+
%7 = or disjoint i64 %6, 1
46+
%11 = getelementptr inbounds half, ptr %arg1, i64 %6
47+
%12 = load half, ptr %11, align 8
48+
%13 = fmul fast half %12, 0xH5380
49+
%14 = fadd fast half %13, 0xH57F0
50+
%16 = getelementptr inbounds half, ptr %arg, i64 %6
51+
store half %14, ptr %16, align 8
52+
%17 = getelementptr inbounds half, ptr %arg1, i64 %7
53+
%18 = load half, ptr %17, align 2
54+
%19 = fmul fast half %18, 0xH5380
55+
%20 = fadd fast half %19, 0xH57F0
56+
%21 = getelementptr inbounds half, ptr %arg, i64 %7
57+
store half %20, ptr %21, align 2
5658
ret void
5759
}
5860

61+
define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
62+
; VECTOR-LABEL: @add_f16(
63+
; VECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
64+
; VECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
65+
; VECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
66+
; VECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
67+
; VECTOR-NEXT: [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
68+
; VECTOR-NEXT: [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
69+
; VECTOR-NEXT: [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
70+
; VECTOR-NEXT: [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
71+
; VECTOR-NEXT: [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
72+
; VECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
73+
; VECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
74+
; VECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
75+
; VECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
76+
; VECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
77+
; VECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
78+
; VECTOR-NEXT: ret void
79+
;
80+
; NOVECTOR-LABEL: @add_f16(
81+
; NOVECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
82+
; NOVECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
83+
; NOVECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
84+
; NOVECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
85+
; NOVECTOR-NEXT: [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
86+
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
87+
; NOVECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
88+
; NOVECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
89+
; NOVECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
90+
; NOVECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
91+
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
92+
; NOVECTOR-NEXT: [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
93+
; NOVECTOR-NEXT: [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
94+
; NOVECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
95+
; NOVECTOR-NEXT: ret void
96+
;
97+
%5 = extractvalue { half, half } %1, 0
98+
%6 = extractvalue { half, half } %1, 1
99+
%7 = extractvalue { half, half } %2, 0
100+
%8 = extractvalue { half, half } %2, 1
101+
%9 = fadd half %5, %7
102+
%10 = fadd half %6, %8
103+
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
104+
%12 = shl i32 %11, 1
105+
%13 = and i32 %12, 62
106+
%14 = zext nneg i32 %13 to i64
107+
%15 = getelementptr half, ptr addrspace(1) %0, i64 %14
108+
%18 = insertelement <2 x half> poison, half %9, i64 0
109+
%19 = insertelement <2 x half> %18, half %10, i64 1
110+
store <2 x half> %19, ptr addrspace(1) %15, align 4
111+
ret void
112+
}
113+
114+
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
115+
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
116+
59117
attributes #0 = { nounwind }
118+
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
119+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
120+
; SM50: {{.*}}
121+
; SM70: {{.*}}
122+
; SM80: {{.*}}
123+
; SM90: {{.*}}

0 commit comments

Comments
 (0)