Skip to content

Commit 294ef57

Browse files
committed
[LLVM] Replace half precision calls to min/max/fma with custom intrinsics
1 parent 5c0d176 commit 294ef57

File tree

1 file changed

+49
-15
lines changed

1 file changed

+49
-15
lines changed

src/llvm_coop_vec.cpp

Lines changed: 49 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -224,24 +224,49 @@ void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
224224
for (uint32_t i = 0; i < v->array_length; ++i)
225225
fmt(" $v_$u = $s $V_$u, $v_$u\n", v, i, op, a0, i, a1, i);
226226
} else {
227-
fmt_intrinsic("declare $T @llvm.$s.v$w$h($T, $T)", v, op, v, a0, a1);
227+
bool custom_intrinsic = false;
228+
#if !defined(__aarch64__)
229+
if ((VarType) a0->type == VarType::Float16) {
230+
if ((JitOp) v->literal == JitOp::Min)
231+
def_minnum_vec_f16_intrinsic();
232+
else
233+
def_maxnum_vec_f16_intrinsic();
234+
custom_intrinsic = true;
235+
}
236+
#endif
237+
if (!custom_intrinsic)
238+
fmt_intrinsic("declare $T @llvm.$s.v$w$h($T, $T)",
239+
v, op, v, a0, a1);
240+
241+
const char *intrinsic_prefix = custom_intrinsic ? "" : "llvm.";
228242
for (uint32_t i = 0; i < v->array_length; ++i)
229-
fmt(" $v_$u = call fast $T @llvm.$s.v$w$h($V_$u, $V_$u)\n",
230-
v, i, v, op, v, a0, i, a1, i);
243+
fmt(" $v_$u = call fast $T @$s$s.v$w$h($V_$u, $V_$u)\n",
244+
v, i, v, intrinsic_prefix, op, v, a0, i, a1, i);
231245
}
232246
}
233247
}
234248
break;
235249

236-
case VarKind::CoopVecTernaryOp:
237-
if ((JitOp) v->literal != JitOp::Fma)
238-
jitc_fail("CoopVecTernaryOp: unsupported operation!");
250+
case VarKind::CoopVecTernaryOp: {
251+
if ((JitOp) v->literal != JitOp::Fma)
252+
jitc_fail("CoopVecTernaryOp: unsupported operation!");
239253

240-
fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)", v, v,
241-
a0, a1, a2);
242-
for (uint32_t i = 0; i < v->array_length; ++i)
243-
fmt(" $v_$u = call $T @llvm.fma.v$w$h($V_$u, $V_$u, $V_$u)\n",
244-
v, i, v, v, a0, i, a1, i, a2, i);
254+
bool custom_intrinsic = false;
255+
#if !defined(__aarch64__)
256+
if ((VarType) a0->type == VarType::Float16) {
257+
def_fma_vec_f16_intrinsic();
258+
custom_intrinsic = true;
259+
}
260+
#endif
261+
if (!custom_intrinsic)
262+
fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)",
263+
v, v, a0, a1, a2);
264+
265+
const char *intrinsic_prefix = custom_intrinsic ? "" : "llvm.";
266+
for (uint32_t i = 0; i < v->array_length; ++i)
267+
fmt(" $v_$u = call $T @$sfma.v$w$h($V_$u, $V_$u, $V_$u)\n",
268+
v, i, v, intrinsic_prefix, v, a0, i, a1, i, a2, i);
269+
}
245270
break;
246271

247272
case VarKind::Bitcast:
@@ -424,16 +449,25 @@ void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
424449
v, v, v, v, v, v, mask, v);
425450
}
426451

427-
fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)", v, v,
428-
v, v, v);
452+
bool custom_intrinsic = false;
453+
#if !defined(__aarch64__)
454+
if ((VarType) v->type == VarType::Float16) {
455+
def_fma_vec_f16_intrinsic();
456+
custom_intrinsic = true;
457+
}
458+
#endif
459+
if (!custom_intrinsic)
460+
fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)",
461+
v, v, v, v, v);
429462

463+
const char *intrinsic_prefix = custom_intrinsic ? "" : "llvm.";
430464
fmt(" $v_y1 = getelementptr inbounds $T, {$T*} $v_po, i32 $v_i\n"
431465
" $v_y = load $T, {$T*} $v_y1, align $A\n"
432-
" $v_r = call $T @llvm.fma.v$w$h($V_a, $V_x, $V_y)\n"
466+
" $v_r = call $T @$sfma.v$w$h($V_a, $V_x, $V_y)\n"
433467
" store $V_r, {$T*} $v_y1, align $A\n",
434468
v, v, v, v, v,
435469
v, v, v, v, v,
436-
v, v, v, v, v, v,
470+
v, v, intrinsic_prefix, v, v, v, v,
437471
v, v, v, v);
438472

439473
fmt(" br label %l$u_inner\n"

0 commit comments

Comments
 (0)