Skip to content

Commit 9f9ad7f

Browse files
committed
[NVPTX] Add fma mix precision intrinsics
This change adds "fma" mix precision operations.
1 parent 497382e commit 9f9ad7f

File tree

3 files changed

+319
-0
lines changed

3 files changed

+319
-0
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,26 @@ let TargetPrefix = "nvvm" in {
11801180
[IntrNoMem, IntrSpeculatable]>;
11811181
}
11821182

1183+
// Mixed-precision fma intrinsics for half and bfloat16 to float
1184+
foreach rnd = ["rn", "rz", "rm", "rp"] in {
1185+
foreach sat = ["", "_sat"] in {
1186+
// Half-precision to float
1187+
def int_nvvm_fma_#rnd#sat#_h_f
1188+
: ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_h_f">,
1189+
DefaultAttrsIntrinsic<[llvm_float_ty],
1190+
[llvm_half_ty, llvm_half_ty, llvm_float_ty],
1191+
[IntrNoMem, IntrSpeculatable]>;
1192+
1193+
// BFloat16 to float
1194+
def int_nvvm_fma_#rnd#sat#_bf_f
1195+
: ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_bf_f">,
1196+
DefaultAttrsIntrinsic<[llvm_float_ty],
1197+
[llvm_bfloat_ty, llvm_bfloat_ty,
1198+
llvm_float_ty],
1199+
[IntrNoMem, IntrSpeculatable]>;
1200+
}
1201+
}
1202+
11831203
//
11841204
// Rcp
11851205
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,6 +1510,27 @@ multiclass FMA_INST {
15101510

15111511
defm INT_NVVM_FMA : FMA_INST;
15121512

1513+
// Define mixed-precision fma instructions for half and bfloat16 to float
1514+
foreach rnd = ["rn", "rz", "rm", "rp"] in {
1515+
foreach sat = ["", "_sat"] in {
1516+
// Half-precision to float
1517+
def INT_NVVM_FMA_#!toupper(rnd#sat)#_H_F
1518+
: F_MATH_3<"fma."#rnd#!subst(
1519+
"_", ".", sat)#".f32.f16 \t$dst, $src0, $src1, $src2;",
1520+
Float32Regs, Int16Regs, Int16Regs, Float32Regs,
1521+
!cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_h_f"),
1522+
[hasPTX<86>, hasSM<100>]>;
1523+
1524+
// BFloat16 to float
1525+
def INT_NVVM_FMA_#!toupper(rnd#sat)#_BF_F
1526+
: F_MATH_3<"fma."#rnd#!subst(
1527+
"_", ".", sat)#".f32.bf16 \t$dst, $src0, $src1, $src2;",
1528+
Float32Regs, Int16Regs, Int16Regs, Float32Regs,
1529+
!cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_bf_f"),
1530+
[hasPTX<86>, hasSM<100>]>;
1531+
}
1532+
}
1533+
15131534
//
15141535
// Rcp
15151536
//
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
3+
4+
; Basic f32.f16 variants with different rounding modes
5+
define float @test_fma_rn_h_f(half %a, half %b, float %c) {
6+
; CHECK-LABEL: test_fma_rn_h_f(
7+
; CHECK: {
8+
; CHECK-NEXT: .reg .b16 %rs<3>;
9+
; CHECK-NEXT: .reg .f32 %f<3>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_h_f_param_0];
13+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_h_f_param_1];
14+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_h_f_param_2];
15+
; CHECK-NEXT: fma.rn.f32.f16 %f2, %rs1, %rs2, %f1;
16+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
17+
; CHECK-NEXT: ret;
18+
%res = call float @llvm.nvvm.fma.rn.h.f(half %a, half %b, float %c)
19+
ret float %res
20+
}
21+
22+
define float @test_fma_rz_h_f(half %a, half %b, float %c) {
23+
; CHECK-LABEL: test_fma_rz_h_f(
24+
; CHECK: {
25+
; CHECK-NEXT: .reg .b16 %rs<3>;
26+
; CHECK-NEXT: .reg .f32 %f<3>;
27+
; CHECK-EMPTY:
28+
; CHECK-NEXT: // %bb.0:
29+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_h_f_param_0];
30+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_h_f_param_1];
31+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_h_f_param_2];
32+
; CHECK-NEXT: fma.rz.f32.f16 %f2, %rs1, %rs2, %f1;
33+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
34+
; CHECK-NEXT: ret;
35+
%res = call float @llvm.nvvm.fma.rz.h.f(half %a, half %b, float %c)
36+
ret float %res
37+
}
38+
39+
define float @test_fma_rm_h_f(half %a, half %b, float %c) {
40+
; CHECK-LABEL: test_fma_rm_h_f(
41+
; CHECK: {
42+
; CHECK-NEXT: .reg .b16 %rs<3>;
43+
; CHECK-NEXT: .reg .f32 %f<3>;
44+
; CHECK-EMPTY:
45+
; CHECK-NEXT: // %bb.0:
46+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_h_f_param_0];
47+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_h_f_param_1];
48+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_h_f_param_2];
49+
; CHECK-NEXT: fma.rm.f32.f16 %f2, %rs1, %rs2, %f1;
50+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
51+
; CHECK-NEXT: ret;
52+
%res = call float @llvm.nvvm.fma.rm.h.f(half %a, half %b, float %c)
53+
ret float %res
54+
}
55+
56+
define float @test_fma_rp_h_f(half %a, half %b, float %c) {
57+
; CHECK-LABEL: test_fma_rp_h_f(
58+
; CHECK: {
59+
; CHECK-NEXT: .reg .b16 %rs<3>;
60+
; CHECK-NEXT: .reg .f32 %f<3>;
61+
; CHECK-EMPTY:
62+
; CHECK-NEXT: // %bb.0:
63+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_h_f_param_0];
64+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_h_f_param_1];
65+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_h_f_param_2];
66+
; CHECK-NEXT: fma.rp.f32.f16 %f2, %rs1, %rs2, %f1;
67+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
68+
; CHECK-NEXT: ret;
69+
%res = call float @llvm.nvvm.fma.rp.h.f(half %a, half %b, float %c)
70+
ret float %res
71+
}
72+
73+
; Basic f32.bf16 variants with different rounding modes
74+
define float @test_fma_rn_bf_f(bfloat %a, bfloat %b, float %c) {
75+
; CHECK-LABEL: test_fma_rn_bf_f(
76+
; CHECK: {
77+
; CHECK-NEXT: .reg .b16 %rs<3>;
78+
; CHECK-NEXT: .reg .f32 %f<3>;
79+
; CHECK-EMPTY:
80+
; CHECK-NEXT: // %bb.0:
81+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_bf_f_param_0];
82+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_bf_f_param_1];
83+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_bf_f_param_2];
84+
; CHECK-NEXT: fma.rn.f32.bf16 %f2, %rs1, %rs2, %f1;
85+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
86+
; CHECK-NEXT: ret;
87+
%res = call float @llvm.nvvm.fma.rn.bf.f(bfloat %a, bfloat %b, float %c)
88+
ret float %res
89+
}
90+
91+
define float @test_fma_rz_bf_f(bfloat %a, bfloat %b, float %c) {
92+
; CHECK-LABEL: test_fma_rz_bf_f(
93+
; CHECK: {
94+
; CHECK-NEXT: .reg .b16 %rs<3>;
95+
; CHECK-NEXT: .reg .f32 %f<3>;
96+
; CHECK-EMPTY:
97+
; CHECK-NEXT: // %bb.0:
98+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_bf_f_param_0];
99+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_bf_f_param_1];
100+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_bf_f_param_2];
101+
; CHECK-NEXT: fma.rz.f32.bf16 %f2, %rs1, %rs2, %f1;
102+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
103+
; CHECK-NEXT: ret;
104+
%res = call float @llvm.nvvm.fma.rz.bf.f(bfloat %a, bfloat %b, float %c)
105+
ret float %res
106+
}
107+
108+
define float @test_fma_rm_bf_f(bfloat %a, bfloat %b, float %c) {
109+
; CHECK-LABEL: test_fma_rm_bf_f(
110+
; CHECK: {
111+
; CHECK-NEXT: .reg .b16 %rs<3>;
112+
; CHECK-NEXT: .reg .f32 %f<3>;
113+
; CHECK-EMPTY:
114+
; CHECK-NEXT: // %bb.0:
115+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_bf_f_param_0];
116+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_bf_f_param_1];
117+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_bf_f_param_2];
118+
; CHECK-NEXT: fma.rm.f32.bf16 %f2, %rs1, %rs2, %f1;
119+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
120+
; CHECK-NEXT: ret;
121+
%res = call float @llvm.nvvm.fma.rm.bf.f(bfloat %a, bfloat %b, float %c)
122+
ret float %res
123+
}
124+
125+
define float @test_fma_rp_bf_f(bfloat %a, bfloat %b, float %c) {
126+
; CHECK-LABEL: test_fma_rp_bf_f(
127+
; CHECK: {
128+
; CHECK-NEXT: .reg .b16 %rs<3>;
129+
; CHECK-NEXT: .reg .f32 %f<3>;
130+
; CHECK-EMPTY:
131+
; CHECK-NEXT: // %bb.0:
132+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_bf_f_param_0];
133+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_bf_f_param_1];
134+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_bf_f_param_2];
135+
; CHECK-NEXT: fma.rp.f32.bf16 %f2, %rs1, %rs2, %f1;
136+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
137+
; CHECK-NEXT: ret;
138+
%res = call float @llvm.nvvm.fma.rp.bf.f(bfloat %a, bfloat %b, float %c)
139+
ret float %res
140+
}
141+
142+
; f32.f16 variants with sat flag
143+
define float @test_fma_rn_sat_h_f(half %a, half %b, float %c) {
144+
; CHECK-LABEL: test_fma_rn_sat_h_f(
145+
; CHECK: {
146+
; CHECK-NEXT: .reg .b16 %rs<3>;
147+
; CHECK-NEXT: .reg .f32 %f<3>;
148+
; CHECK-EMPTY:
149+
; CHECK-NEXT: // %bb.0:
150+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_h_f_param_0];
151+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_h_f_param_1];
152+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_h_f_param_2];
153+
; CHECK-NEXT: fma.rn.sat.f32.f16 %f2, %rs1, %rs2, %f1;
154+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
155+
; CHECK-NEXT: ret;
156+
%res = call float @llvm.nvvm.fma.rn.sat.h.f(half %a, half %b, float %c)
157+
ret float %res
158+
}
159+
160+
define float @test_fma_rz_sat_h_f(half %a, half %b, float %c) {
161+
; CHECK-LABEL: test_fma_rz_sat_h_f(
162+
; CHECK: {
163+
; CHECK-NEXT: .reg .b16 %rs<3>;
164+
; CHECK-NEXT: .reg .f32 %f<3>;
165+
; CHECK-EMPTY:
166+
; CHECK-NEXT: // %bb.0:
167+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_h_f_param_0];
168+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_h_f_param_1];
169+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_h_f_param_2];
170+
; CHECK-NEXT: fma.rz.sat.f32.f16 %f2, %rs1, %rs2, %f1;
171+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
172+
; CHECK-NEXT: ret;
173+
%res = call float @llvm.nvvm.fma.rz.sat.h.f(half %a, half %b, float %c)
174+
ret float %res
175+
}
176+
177+
define float @test_fma_rm_sat_h_f(half %a, half %b, float %c) {
178+
; CHECK-LABEL: test_fma_rm_sat_h_f(
179+
; CHECK: {
180+
; CHECK-NEXT: .reg .b16 %rs<3>;
181+
; CHECK-NEXT: .reg .f32 %f<3>;
182+
; CHECK-EMPTY:
183+
; CHECK-NEXT: // %bb.0:
184+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_h_f_param_0];
185+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_h_f_param_1];
186+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_h_f_param_2];
187+
; CHECK-NEXT: fma.rm.sat.f32.f16 %f2, %rs1, %rs2, %f1;
188+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
189+
; CHECK-NEXT: ret;
190+
%res = call float @llvm.nvvm.fma.rm.sat.h.f(half %a, half %b, float %c)
191+
ret float %res
192+
}
193+
194+
define float @test_fma_rp_sat_h_f(half %a, half %b, float %c) {
195+
; CHECK-LABEL: test_fma_rp_sat_h_f(
196+
; CHECK: {
197+
; CHECK-NEXT: .reg .b16 %rs<3>;
198+
; CHECK-NEXT: .reg .f32 %f<3>;
199+
; CHECK-EMPTY:
200+
; CHECK-NEXT: // %bb.0:
201+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_h_f_param_0];
202+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_h_f_param_1];
203+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_h_f_param_2];
204+
; CHECK-NEXT: fma.rp.sat.f32.f16 %f2, %rs1, %rs2, %f1;
205+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
206+
; CHECK-NEXT: ret;
207+
%res = call float @llvm.nvvm.fma.rp.sat.h.f(half %a, half %b, float %c)
208+
ret float %res
209+
}
210+
211+
; f32.bf16 variants with sat flag
212+
define float @test_fma_rn_sat_bf_f(bfloat %a, bfloat %b, float %c) {
213+
; CHECK-LABEL: test_fma_rn_sat_bf_f(
214+
; CHECK: {
215+
; CHECK-NEXT: .reg .b16 %rs<3>;
216+
; CHECK-NEXT: .reg .f32 %f<3>;
217+
; CHECK-EMPTY:
218+
; CHECK-NEXT: // %bb.0:
219+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_bf_f_param_0];
220+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_bf_f_param_1];
221+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_bf_f_param_2];
222+
; CHECK-NEXT: fma.rn.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
223+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
224+
; CHECK-NEXT: ret;
225+
%res = call float @llvm.nvvm.fma.rn.sat.bf.f(bfloat %a, bfloat %b, float %c)
226+
ret float %res
227+
}
228+
229+
define float @test_fma_rz_sat_bf_f(bfloat %a, bfloat %b, float %c) {
230+
; CHECK-LABEL: test_fma_rz_sat_bf_f(
231+
; CHECK: {
232+
; CHECK-NEXT: .reg .b16 %rs<3>;
233+
; CHECK-NEXT: .reg .f32 %f<3>;
234+
; CHECK-EMPTY:
235+
; CHECK-NEXT: // %bb.0:
236+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_bf_f_param_0];
237+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_bf_f_param_1];
238+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_bf_f_param_2];
239+
; CHECK-NEXT: fma.rz.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
240+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
241+
; CHECK-NEXT: ret;
242+
%res = call float @llvm.nvvm.fma.rz.sat.bf.f(bfloat %a, bfloat %b, float %c)
243+
ret float %res
244+
}
245+
246+
define float @test_fma_rm_sat_bf_f(bfloat %a, bfloat %b, float %c) {
247+
; CHECK-LABEL: test_fma_rm_sat_bf_f(
248+
; CHECK: {
249+
; CHECK-NEXT: .reg .b16 %rs<3>;
250+
; CHECK-NEXT: .reg .f32 %f<3>;
251+
; CHECK-EMPTY:
252+
; CHECK-NEXT: // %bb.0:
253+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_bf_f_param_0];
254+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_bf_f_param_1];
255+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_bf_f_param_2];
256+
; CHECK-NEXT: fma.rm.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
257+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
258+
; CHECK-NEXT: ret;
259+
%res = call float @llvm.nvvm.fma.rm.sat.bf.f(bfloat %a, bfloat %b, float %c)
260+
ret float %res
261+
}
262+
263+
define float @test_fma_rp_sat_bf_f(bfloat %a, bfloat %b, float %c) {
264+
; CHECK-LABEL: test_fma_rp_sat_bf_f(
265+
; CHECK: {
266+
; CHECK-NEXT: .reg .b16 %rs<3>;
267+
; CHECK-NEXT: .reg .f32 %f<3>;
268+
; CHECK-EMPTY:
269+
; CHECK-NEXT: // %bb.0:
270+
; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_bf_f_param_0];
271+
; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_bf_f_param_1];
272+
; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_bf_f_param_2];
273+
; CHECK-NEXT: fma.rp.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
274+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
275+
; CHECK-NEXT: ret;
276+
%res = call float @llvm.nvvm.fma.rp.sat.bf.f(bfloat %a, bfloat %b, float %c)
277+
ret float %res
278+
}

0 commit comments

Comments
 (0)