Skip to content

Commit 0dc2148

Browse files
[NVPTX] Add 3-operand fmin/fmax DAGCombines (#159729)
Add DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+).
1 parent bb79448 commit 0dc2148

File tree

2 files changed

+326
-4
lines changed

2 files changed

+326
-4
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841841
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
842842

843843
// We have some custom DAG combine patterns for these nodes
844-
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
845-
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
846-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
847-
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
844+
setTargetDAGCombine(
845+
{ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
846+
ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
847+
ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
848+
ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
849+
ISD::SREM, ISD::UREM, ISD::VSELECT,
850+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
851+
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
848852

849853
// setcc for f16x2 and bf16x2 needs special handling to prevent
850854
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
53165320
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
53175321
}
53185322

5323+
/// Get 3-input version of a 2-input min/max opcode
5324+
static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
5325+
switch (MinMax2Opcode) {
5326+
case ISD::FMAXNUM:
5327+
case ISD::FMAXIMUMNUM:
5328+
return NVPTXISD::FMAXNUM3;
5329+
case ISD::FMINNUM:
5330+
case ISD::FMINIMUMNUM:
5331+
return NVPTXISD::FMINNUM3;
5332+
case ISD::FMAXIMUM:
5333+
return NVPTXISD::FMAXIMUM3;
5334+
case ISD::FMINIMUM:
5335+
return NVPTXISD::FMINIMUM3;
5336+
default:
5337+
llvm_unreachable("Invalid 2-input min/max opcode");
5338+
}
5339+
}
5340+
5341+
/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
5342+
/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
5343+
static SDValue PerformFMinMaxCombine(SDNode *N,
5344+
TargetLowering::DAGCombinerInfo &DCI,
5345+
unsigned PTXVersion, unsigned SmVersion) {
5346+
5347+
// 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
5348+
EVT VT = N->getValueType(0);
5349+
if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
5350+
return SDValue();
5351+
5352+
SDValue Op0 = N->getOperand(0);
5353+
SDValue Op1 = N->getOperand(1);
5354+
unsigned MinMaxOp2 = N->getOpcode();
5355+
NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
5356+
5357+
if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
5358+
// (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
5359+
SDValue A = Op0.getOperand(0);
5360+
SDValue B = Op0.getOperand(1);
5361+
SDValue C = Op1;
5362+
return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
5363+
} else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
5364+
// (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
5365+
SDValue A = Op0;
5366+
SDValue B = Op1.getOperand(0);
5367+
SDValue C = Op1.getOperand(1);
5368+
return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
5369+
}
5370+
return SDValue();
5371+
}
5372+
53195373
static SDValue PerformREMCombine(SDNode *N,
53205374
TargetLowering::DAGCombinerInfo &DCI,
53215375
CodeGenOptLevel OptLevel) {
@@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59966050
return PerformEXTRACTCombine(N, DCI);
59976051
case ISD::FADD:
59986052
return PerformFADDCombine(N, DCI, OptLevel);
6053+
case ISD::FMAXNUM:
6054+
case ISD::FMINNUM:
6055+
case ISD::FMAXIMUM:
6056+
case ISD::FMINIMUM:
6057+
case ISD::FMAXIMUMNUM:
6058+
case ISD::FMINIMUMNUM:
6059+
return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
6060+
STI.getSmVersion());
59996061
case ISD::LOAD:
60006062
case NVPTXISD::LoadV2:
60016063
case NVPTXISD::LoadV4:

llvm/test/CodeGen/NVPTX/fmax3.ll

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s
3+
4+
target triple = "nvptx64-nvidia-cuda"
5+
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
6+
7+
define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
8+
; CHECK-LABEL: test_fmaxnum3(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b32 %r<5>;
11+
; CHECK-NEXT: .reg .b64 %rd<2>;
12+
; CHECK-EMPTY:
13+
; CHECK-NEXT: // %bb.0: // %entry
14+
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_param_0];
15+
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_param_1];
16+
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_param_2];
17+
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
18+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_param_3];
19+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
20+
; CHECK-NEXT: ret;
21+
entry:
22+
%max_ab = call float @llvm.maxnum.f32(float %a, float %b)
23+
%max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
24+
store float %max_abc, ptr addrspace(1) %output, align 4
25+
ret void
26+
}
27+
28+
define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
29+
; CHECK-LABEL: test_fminnum3(
30+
; CHECK: {
31+
; CHECK-NEXT: .reg .b32 %r<5>;
32+
; CHECK-NEXT: .reg .b64 %rd<2>;
33+
; CHECK-EMPTY:
34+
; CHECK-NEXT: // %bb.0: // %entry
35+
; CHECK-NEXT: ld.param.b32 %r1, [test_fminnum3_param_0];
36+
; CHECK-NEXT: ld.param.b32 %r2, [test_fminnum3_param_1];
37+
; CHECK-NEXT: ld.param.b32 %r3, [test_fminnum3_param_2];
38+
; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
39+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminnum3_param_3];
40+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
41+
; CHECK-NEXT: ret;
42+
entry:
43+
%min_ab = call float @llvm.minnum.f32(float %a, float %b)
44+
%min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
45+
store float %min_abc, ptr addrspace(1) %output, align 4
46+
ret void
47+
}
48+
49+
define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
50+
; CHECK-LABEL: test_fmaximum3(
51+
; CHECK: {
52+
; CHECK-NEXT: .reg .b32 %r<5>;
53+
; CHECK-NEXT: .reg .b64 %rd<2>;
54+
; CHECK-EMPTY:
55+
; CHECK-NEXT: // %bb.0: // %entry
56+
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximum3_param_0];
57+
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximum3_param_1];
58+
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximum3_param_2];
59+
; CHECK-NEXT: max.NaN.f32 %r4, %r1, %r2, %r3;
60+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximum3_param_3];
61+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
62+
; CHECK-NEXT: ret;
63+
entry:
64+
%max_ab = call float @llvm.maximum.f32(float %a, float %b)
65+
%max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
66+
store float %max_abc, ptr addrspace(1) %output, align 4
67+
ret void
68+
}
69+
70+
define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
71+
; CHECK-LABEL: test_fminimum3(
72+
; CHECK: {
73+
; CHECK-NEXT: .reg .b32 %r<5>;
74+
; CHECK-NEXT: .reg .b64 %rd<2>;
75+
; CHECK-EMPTY:
76+
; CHECK-NEXT: // %bb.0: // %entry
77+
; CHECK-NEXT: ld.param.b32 %r1, [test_fminimum3_param_0];
78+
; CHECK-NEXT: ld.param.b32 %r2, [test_fminimum3_param_1];
79+
; CHECK-NEXT: ld.param.b32 %r3, [test_fminimum3_param_2];
80+
; CHECK-NEXT: min.NaN.f32 %r4, %r1, %r2, %r3;
81+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimum3_param_3];
82+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
83+
; CHECK-NEXT: ret;
84+
entry:
85+
%min_ab = call float @llvm.minimum.f32(float %a, float %b)
86+
%min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
87+
store float %min_abc, ptr addrspace(1) %output, align 4
88+
ret void
89+
}
90+
91+
define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
92+
; CHECK-LABEL: test_fmaximumnum3(
93+
; CHECK: {
94+
; CHECK-NEXT: .reg .b32 %r<5>;
95+
; CHECK-NEXT: .reg .b64 %rd<2>;
96+
; CHECK-EMPTY:
97+
; CHECK-NEXT: // %bb.0: // %entry
98+
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximumnum3_param_0];
99+
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximumnum3_param_1];
100+
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximumnum3_param_2];
101+
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
102+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
103+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
104+
; CHECK-NEXT: ret;
105+
entry:
106+
%max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
107+
%max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
108+
store float %max_abc, ptr addrspace(1) %output, align 4
109+
ret void
110+
}
111+
112+
define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
113+
; CHECK-LABEL: test_fminimumnum3(
114+
; CHECK: {
115+
; CHECK-NEXT: .reg .b32 %r<5>;
116+
; CHECK-NEXT: .reg .b64 %rd<2>;
117+
; CHECK-EMPTY:
118+
; CHECK-NEXT: // %bb.0: // %entry
119+
; CHECK-NEXT: ld.param.b32 %r1, [test_fminimumnum3_param_0];
120+
; CHECK-NEXT: ld.param.b32 %r2, [test_fminimumnum3_param_1];
121+
; CHECK-NEXT: ld.param.b32 %r3, [test_fminimumnum3_param_2];
122+
; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
123+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimumnum3_param_3];
124+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
125+
; CHECK-NEXT: ret;
126+
entry:
127+
%min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
128+
%min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
129+
store float %min_abc, ptr addrspace(1) %output, align 4
130+
ret void
131+
}
132+
133+
; Test commuted operands (second operand is the nested operation)
134+
define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
135+
; CHECK-LABEL: test_fmaxnum3_commuted(
136+
; CHECK: {
137+
; CHECK-NEXT: .reg .b32 %r<5>;
138+
; CHECK-NEXT: .reg .b64 %rd<2>;
139+
; CHECK-EMPTY:
140+
; CHECK-NEXT: // %bb.0: // %entry
141+
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
142+
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
143+
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
144+
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
145+
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
146+
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
147+
; CHECK-NEXT: ret;
148+
entry:
149+
%max_bc = call float @llvm.maxnum.f32(float %b, float %c)
150+
%max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
151+
store float %max_abc, ptr addrspace(1) %output, align 4
152+
ret void
153+
}
154+
155+
; NEGATIVE TEST: Mixed min/max operations should not combine
156+
define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
157+
; CHECK-LABEL: test_mixed_minmax_no_combine(
158+
; CHECK: {
159+
; CHECK-NEXT: .reg .b32 %r<6>;
160+
; CHECK-NEXT: .reg .b64 %rd<2>;
161+
; CHECK-EMPTY:
162+
; CHECK-NEXT: // %bb.0: // %entry
163+
; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
164+
; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
165+
; CHECK-NEXT: min.f32 %r3, %r1, %r2;
166+
; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
167+
; CHECK-NEXT: max.f32 %r5, %r3, %r4;
168+
; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
169+
; CHECK-NEXT: st.global.b32 [%rd1], %r5;
170+
; CHECK-NEXT: ret;
171+
entry:
172+
%min_ab = call float @llvm.minnum.f32(float %a, float %b)
173+
%max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
174+
store float %max_result, ptr addrspace(1) %output, align 4
175+
ret void
176+
}
177+
178+
; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine
179+
define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
180+
; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
181+
; CHECK: {
182+
; CHECK-NEXT: .reg .b32 %r<6>;
183+
; CHECK-NEXT: .reg .b64 %rd<2>;
184+
; CHECK-EMPTY:
185+
; CHECK-NEXT: // %bb.0: // %entry
186+
; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
187+
; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
188+
; CHECK-NEXT: max.f32 %r3, %r1, %r2;
189+
; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
190+
; CHECK-NEXT: max.NaN.f32 %r5, %r3, %r4;
191+
; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
192+
; CHECK-NEXT: st.global.b32 [%rd1], %r5;
193+
; CHECK-NEXT: ret;
194+
entry:
195+
%maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
196+
%maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
197+
store float %maximum_result, ptr addrspace(1) %output, align 4
198+
ret void
199+
}
200+
201+
; NEGATIVE TEST: f16 should not be combined (only f32 supported)
202+
define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
203+
; CHECK-LABEL: test_f16_no_combine(
204+
; CHECK: {
205+
; CHECK-NEXT: .reg .b16 %rs<6>;
206+
; CHECK-NEXT: .reg .b64 %rd<2>;
207+
; CHECK-EMPTY:
208+
; CHECK-NEXT: // %bb.0: // %entry
209+
; CHECK-NEXT: ld.param.b16 %rs1, [test_f16_no_combine_param_0];
210+
; CHECK-NEXT: ld.param.b16 %rs2, [test_f16_no_combine_param_1];
211+
; CHECK-NEXT: max.f16 %rs3, %rs1, %rs2;
212+
; CHECK-NEXT: ld.param.b16 %rs4, [test_f16_no_combine_param_2];
213+
; CHECK-NEXT: max.f16 %rs5, %rs3, %rs4;
214+
; CHECK-NEXT: ld.param.b64 %rd1, [test_f16_no_combine_param_3];
215+
; CHECK-NEXT: st.global.b16 [%rd1], %rs5;
216+
; CHECK-NEXT: ret;
217+
entry:
218+
%max_ab = call half @llvm.maxnum.f16(half %a, half %b)
219+
%max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
220+
store half %max_abc, ptr addrspace(1) %output, align 2
221+
ret void
222+
}
223+
224+
; NEGATIVE TEST: Multiple uses of intermediate result should not combine
225+
define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
226+
; CHECK-LABEL: test_multiple_uses_no_combine(
227+
; CHECK: {
228+
; CHECK-NEXT: .reg .b32 %r<6>;
229+
; CHECK-NEXT: .reg .b64 %rd<3>;
230+
; CHECK-EMPTY:
231+
; CHECK-NEXT: // %bb.0: // %entry
232+
; CHECK-NEXT: ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
233+
; CHECK-NEXT: ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
234+
; CHECK-NEXT: max.f32 %r3, %r1, %r2;
235+
; CHECK-NEXT: ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
236+
; CHECK-NEXT: max.f32 %r5, %r3, %r4;
237+
; CHECK-NEXT: ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
238+
; CHECK-NEXT: st.global.b32 [%rd1], %r3;
239+
; CHECK-NEXT: ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
240+
; CHECK-NEXT: st.global.b32 [%rd2], %r5;
241+
; CHECK-NEXT: ret;
242+
entry:
243+
%max_ab = call float @llvm.maxnum.f32(float %a, float %b)
244+
%max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
245+
; Multiple uses of %max_ab should prevent combining
246+
store float %max_ab, ptr addrspace(1) %output1, align 4
247+
store float %max_abc, ptr addrspace(1) %output2, align 4
248+
ret void
249+
}
250+
251+
; Declare all the intrinsics we need
252+
declare float @llvm.maxnum.f32(float, float) #0
253+
declare float @llvm.minnum.f32(float, float) #0
254+
declare float @llvm.maximum.f32(float, float) #0
255+
declare float @llvm.minimum.f32(float, float) #0
256+
declare float @llvm.maximumnum.f32(float, float) #0
257+
declare float @llvm.minimumnum.f32(float, float) #0
258+
declare half @llvm.maxnum.f16(half, half) #0
259+
260+
attributes #0 = { nounwind readnone speculatable willreturn }

0 commit comments

Comments
 (0)