Skip to content

Commit 55430f8

Browse files
authored
[NVPTX] Customize getScalarizationOverhead (#128077)
We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost to build a vector. PTX has a single `mov` instruction that can build e.g. `<2 x half>` vectors from scalars, however the SLPVectorizer over-estimates it as the cost of 2 insert elements. To fix this I customize `getScalarizationOverhead` to lower the cost for building 2x16 types.
1 parent 22a11be commit 55430f8

File tree

2 files changed

+143
-47
lines changed

2 files changed

+143
-47
lines changed

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
1717
#define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
1818

19-
#include "NVPTXTargetMachine.h"
2019
#include "MCTargetDesc/NVPTXBaseInfo.h"
20+
#include "NVPTXTargetMachine.h"
21+
#include "NVPTXUtilities.h"
2122
#include "llvm/Analysis/TargetTransformInfo.h"
2223
#include "llvm/CodeGen/BasicTTIImpl.h"
2324
#include "llvm/CodeGen/TargetLowering.h"
@@ -104,6 +105,42 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
104105
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
105106
ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
106107

108+
/// Estimate the cost of assembling/disassembling a vector value on NVPTX.
///
/// The generic BasicTTI estimate charges one insertelement per demanded
/// lane, which is too pessimistic here: PTX builds 2x16-bit vectors
/// (e.g. <2 x half>, <2 x bfloat>, <2 x i16>) with a single `mov`, and
/// <4 x i8> with a short PRMT sequence. Lowering these costs lets the
/// SLPVectorizer vectorize profitably (see PR #128077).
///
/// \param InTy         vector type being scalarized/built.
/// \param DemandedElts bitmask of lanes actually used.
/// \param Insert       whether build (insertelement) cost is requested.
/// \param Extract      whether extractelement cost is requested.
/// \param VL           optional scalar operands the vector is built from.
/// \returns the cost, or an invalid cost for scalable vectors (which
///          NVPTX does not support).
InstructionCost getScalarizationOverhead(VectorType *InTy,
                                         const APInt &DemandedElts,
                                         bool Insert, bool Extract,
                                         TTI::TargetCostKind CostKind,
                                         ArrayRef<Value *> VL = {}) {
  if (!InTy->getElementCount().isFixed())
    return InstructionCost::getInvalid();

  auto VT = getTLI()->getValueType(DL, InTy);
  auto NumElements = InTy->getElementCount().getFixedValue();
  InstructionCost Cost = 0;
  if (Insert && !VL.empty()) {
    // A vector built entirely from constants materializes as an
    // immediate; the inserts cost nothing.
    bool AllConstant = all_of(seq(NumElements), [&](int Idx) {
      return !DemandedElts[Idx] || isa<Constant>(VL[Idx]);
    });
    if (AllConstant) {
      Cost += TTI::TCC_Free;
      Insert = false;
    }
  }
  if (Insert && Isv2x16VT(VT)) {
    // Can be built in a single mov.
    Cost += 1;
    Insert = false;
  }
  if (Insert && VT == MVT::v4i8) {
    // Accumulate into the outer Cost: a fresh local here would shadow
    // it and silently discard the v4i8 build cost.
    Cost += 3; // 3 x PRMT
    for (auto Idx : seq(NumElements))
      if (DemandedElts[Idx])
        Cost += 1; // zext operand to i32
    Insert = false;
  }
  return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert,
                                                Extract, CostKind, VL);
}
143+
107144
void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
108145
TTI::UnrollingPreferences &UP,
109146
OptimizationRemarkEmitter *ORE);
Lines changed: 105 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,118 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2-
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
3-
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
2+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefix=VECTOR
3+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefix=VECTOR
4+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefix=VECTOR
5+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefix=NOVECTOR
46

57
define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
6-
; CHECK-LABEL: @fusion(
7-
; CHECK-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
8-
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
9-
; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
10-
; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
11-
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
12-
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
13-
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
14-
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
15-
; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
16-
; CHECK-NEXT: store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
17-
; CHECK-NEXT: ret void
8+
; VECTOR-LABEL: @fusion(
9+
; VECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
10+
; VECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
11+
; VECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
12+
; VECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
13+
; VECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
14+
; VECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
15+
; VECTOR-NEXT: [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
16+
; VECTOR-NEXT: [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
17+
; VECTOR-NEXT: [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
18+
; VECTOR-NEXT: store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
19+
; VECTOR-NEXT: ret void
1820
;
1921
; NOVECTOR-LABEL: @fusion(
20-
; NOVECTOR-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
21-
; NOVECTOR-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
22-
; NOVECTOR-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
23-
; NOVECTOR-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
24-
; NOVECTOR-NEXT: [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
25-
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
26-
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
22+
; NOVECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
23+
; NOVECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
24+
; NOVECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
25+
; NOVECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
26+
; NOVECTOR-NEXT: [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
27+
; NOVECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
28+
; NOVECTOR-NEXT: [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
29+
; NOVECTOR-NEXT: [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
30+
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
31+
; NOVECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
32+
; NOVECTOR-NEXT: store half [[TMP9]], ptr [[TMP6]], align 8
33+
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
34+
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
2735
; NOVECTOR-NEXT: [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
2836
; NOVECTOR-NEXT: [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
29-
; NOVECTOR-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
30-
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP16]], align 8
31-
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
32-
; NOVECTOR-NEXT: [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
33-
; NOVECTOR-NEXT: [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
34-
; NOVECTOR-NEXT: [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
35-
; NOVECTOR-NEXT: [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
36-
; NOVECTOR-NEXT: store half [[TMP20]], ptr [[TMP21]], align 2
37+
; NOVECTOR-NEXT: [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
38+
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP15]], align 2
3739
; NOVECTOR-NEXT: ret void
3840
;
39-
%tmp = shl nuw nsw i32 %arg2, 6
40-
%tmp4 = or i32 %tmp, %arg3
41-
%tmp5 = shl nuw nsw i32 %tmp4, 2
42-
%tmp6 = zext i32 %tmp5 to i64
43-
%tmp7 = or disjoint i64 %tmp6, 1
44-
%tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
45-
%tmp12 = load half, ptr %tmp11, align 8
46-
%tmp13 = fmul fast half %tmp12, 0xH5380
47-
%tmp14 = fadd fast half %tmp13, 0xH57F0
48-
%tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
49-
store half %tmp14, ptr %tmp16, align 8
50-
%tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
51-
%tmp18 = load half, ptr %tmp17, align 2
52-
%tmp19 = fmul fast half %tmp18, 0xH5380
53-
%tmp20 = fadd fast half %tmp19, 0xH57F0
54-
%tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
55-
store half %tmp20, ptr %tmp21, align 2
41+
%1 = shl nuw nsw i32 %arg2, 6
42+
%4 = or i32 %1, %arg3
43+
%5 = shl nuw nsw i32 %4, 2
44+
%6 = zext i32 %5 to i64
45+
%7 = or disjoint i64 %6, 1
46+
%11 = getelementptr inbounds half, ptr %arg1, i64 %6
47+
%12 = load half, ptr %11, align 8
48+
%13 = fmul fast half %12, 0xH5380
49+
%14 = fadd fast half %13, 0xH57F0
50+
%16 = getelementptr inbounds half, ptr %arg, i64 %6
51+
store half %14, ptr %16, align 8
52+
%17 = getelementptr inbounds half, ptr %arg1, i64 %7
53+
%18 = load half, ptr %17, align 2
54+
%19 = fmul fast half %18, 0xH5380
55+
%20 = fadd fast half %19, 0xH57F0
56+
%21 = getelementptr inbounds half, ptr %arg, i64 %7
57+
store half %20, ptr %21, align 2
5658
ret void
5759
}
5860

61+
define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
62+
; VECTOR-LABEL: @add_f16(
63+
; VECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
64+
; VECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
65+
; VECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
66+
; VECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
67+
; VECTOR-NEXT: [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
68+
; VECTOR-NEXT: [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
69+
; VECTOR-NEXT: [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
70+
; VECTOR-NEXT: [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
71+
; VECTOR-NEXT: [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
72+
; VECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
73+
; VECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
74+
; VECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
75+
; VECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
76+
; VECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
77+
; VECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
78+
; VECTOR-NEXT: ret void
79+
;
80+
; NOVECTOR-LABEL: @add_f16(
81+
; NOVECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
82+
; NOVECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
83+
; NOVECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
84+
; NOVECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
85+
; NOVECTOR-NEXT: [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
86+
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
87+
; NOVECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
88+
; NOVECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
89+
; NOVECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
90+
; NOVECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
91+
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
92+
; NOVECTOR-NEXT: [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
93+
; NOVECTOR-NEXT: [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
94+
; NOVECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
95+
; NOVECTOR-NEXT: ret void
96+
;
97+
%5 = extractvalue { half, half } %1, 0
98+
%6 = extractvalue { half, half } %1, 1
99+
%7 = extractvalue { half, half } %2, 0
100+
%8 = extractvalue { half, half } %2, 1
101+
%9 = fadd half %5, %7
102+
%10 = fadd half %6, %8
103+
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
104+
%12 = shl i32 %11, 1
105+
%13 = and i32 %12, 62
106+
%14 = zext nneg i32 %13 to i64
107+
%15 = getelementptr half, ptr addrspace(1) %0, i64 %14
108+
%18 = insertelement <2 x half> poison, half %9, i64 0
109+
%19 = insertelement <2 x half> %18, half %10, i64 1
110+
store <2 x half> %19, ptr addrspace(1) %15, align 4
111+
ret void
112+
}
113+
114+
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
115+
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
116+
59117
attributes #0 = { nounwind }
118+
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

0 commit comments

Comments
 (0)