Skip to content

Commit 604eab1

Browse files
committed
[LLVM] Add generic half precision intrinsic wrappers
1 parent 294ef57 commit 604eab1

File tree

2 files changed

+28
-40
lines changed

2 files changed

+28
-40
lines changed

src/llvm_coop_vec.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,7 @@ void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
227227
bool custom_intrinsic = false;
228228
#if !defined(__aarch64__)
229229
if ((VarType) a0->type == VarType::Float16) {
230-
if ((JitOp) v->literal == JitOp::Min)
231-
def_minnum_vec_f16_intrinsic();
232-
else
233-
def_maxnum_vec_f16_intrinsic();
230+
def_f16_wrapper_binary_intrinsic(op);
234231
custom_intrinsic = true;
235232
}
236233
#endif
@@ -254,7 +251,7 @@ void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
254251
bool custom_intrinsic = false;
255252
#if !defined(__aarch64__)
256253
if ((VarType) a0->type == VarType::Float16) {
257-
def_fma_vec_f16_intrinsic();
254+
def_f16_wrapper_ternary_intrinsic("fma");
258255
custom_intrinsic = true;
259256
}
260257
#endif
@@ -452,7 +449,7 @@ void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
452449
bool custom_intrinsic = false;
453450
#if !defined(__aarch64__)
454451
if ((VarType) v->type == VarType::Float16) {
455-
def_fma_vec_f16_intrinsic();
452+
def_f16_wrapper_ternary_intrinsic("fma");
456453
custom_intrinsic = true;
457454
}
458455
#endif

src/llvm_eval.h

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,30 @@
1313
} while (0)
1414

1515
#if !defined(__aarch64__)
16+
#define def_f16_wrapper_binary_intrinsic(op_str) \
17+
do { \
18+
fmt_intrinsic("declare <$w x float> @llvm.$s.v$wf32(<$w x float>, <$w x float>)", op_str); \
19+
fmt_intrinsic( \
20+
"define internal <$w x half> @$s.v$wf16(<$w x half> %a, <$w x half> %b) #0 ${\n" \
21+
" %a_f32 = fpext <$w x half> %a to <$w x float>\n" \
22+
" %b_f32 = fpext <$w x half> %b to <$w x float>\n" \
23+
" %out_f32 = call fast <$w x float> @llvm.$s.v$wf32(<$w x float> %a_f32, <$w x float> %b_f32)\n" \
24+
" %out = fptrunc <$w x float> %out_f32 to <$w x half>\n" \
25+
" ret <$w x half> %out\n" \
26+
"$}", op_str, op_str); \
27+
} while (0)
1628

17-
#define def_fma_vec_f16_intrinsic() \
18-
fmt_intrinsic( \
19-
"define internal <$w x half> @fma.v$wf16(<$w x half> %a, <$w x half> %b, <$w x half> %c) #0 ${\n" \
20-
" %a_f32 = fpext <$w x half> %a to <$w x float>\n" \
21-
" %b_f32 = fpext <$w x half> %b to <$w x float>\n" \
22-
" %c_f32 = fpext <$w x half> %c to <$w x float>\n" \
23-
" %out_f32 = call fast <$w x float> @llvm.fma.v$wf32(<$w x float> %a_f32, <$w x float> %b_f32, <$w x float> %c_f32)\n" \
24-
" %out = fptrunc <$w x float> %out_f32 to <$w x half>\n" \
25-
" ret <$w x half> %out\n" \
26-
"$}" \
27-
)
28-
29-
#define def_minnum_vec_f16_intrinsic() \
30-
fmt_intrinsic( \
31-
"define internal <$w x half> @minnum.v$wf16(<$w x half> %a, <$w x half> %b) local_unnamed_addr #0 ${\n" \
32-
" %a_f32 = fpext <$w x half> %a to <$w x float>\n" \
33-
" %b_f32 = fpext <$w x half> %b to <$w x float>\n" \
34-
" %out_f32 = call fast <$w x float> @llvm.minnum.v$wf32(<$w x float> %a_f32, <$w x float> %b_f32)\n" \
35-
" %out = fptrunc <$w x float> %out_f32 to <$w x half>\n" \
36-
" ret <$w x half> %out\n" \
37-
"$}" \
38-
)
39-
40-
#define def_maxnum_vec_f16_intrinsic() \
41-
fmt_intrinsic( \
42-
"define internal <$w x half> @maxnum.v$wf16(<$w x half> %a, <$w x half> %b) local_unnamed_addr #0 ${\n" \
43-
" %a_f32 = fpext <$w x half> %a to <$w x float>\n" \
44-
" %b_f32 = fpext <$w x half> %b to <$w x float>\n" \
45-
" %out_f32 = call fast <$w x float> @llvm.maxnum.v$wf32(<$w x float> %a_f32, <$w x float> %b_f32)\n" \
46-
" %out = fptrunc <$w x float> %out_f32 to <$w x half>\n" \
47-
" ret <$w x half> %out\n" \
48-
"$}" \
49-
)
50-
29+
#define def_f16_wrapper_ternary_intrinsic(op_str) \
30+
do { \
31+
fmt_intrinsic("declare <$w x float> @llvm.$s.v$wf32(<$w x float>, <$w x float>, <$w x float>)", op_str); \
32+
fmt_intrinsic( \
33+
"define internal <$w x half> @$s.v$wf16(<$w x half> %a, <$w x half> %b, <$w x half> %c) #0 ${\n" \
34+
" %a_f32 = fpext <$w x half> %a to <$w x float>\n" \
35+
" %b_f32 = fpext <$w x half> %b to <$w x float>\n" \
36+
" %c_f32 = fpext <$w x half> %c to <$w x float>\n" \
37+
" %out_f32 = call fast <$w x float> @llvm.$s.v$wf32(<$w x float> %a_f32, <$w x float> %b_f32, <$w x float> %c_f32)\n" \
38+
" %out = fptrunc <$w x float> %out_f32 to <$w x half>\n" \
39+
" ret <$w x half> %out\n" \
40+
"$}", op_str, op_str); \
41+
} while (0)
5142
#endif

0 commit comments

Comments
 (0)