Skip to content

Commit 67e992d

Browse files
xal-0vtjnash
andauthored
Runtime intrinsics: fix fpext and fptrunc behaviour on Float16/BFloat16 (#57160)
This makes two changes `fpext` and `fptrunc` to match the behaviour specified in their error strings: -`fpext` works when converting from Float16 => Float16, -`fptrunc` is prevented from truncating Float16 => Float16 Both are re-written to make it explicit what conversions are possible, and how they are done. Closes #57130. --------- Co-authored-by: Jameson Nash <[email protected]>
1 parent fea26dd commit 67e992d

File tree

5 files changed

+250
-73
lines changed

5 files changed

+250
-73
lines changed

Compiler/src/tfuncs.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,6 +2454,9 @@ const _SPECIAL_BUILTINS = Any[
24542454
Core._apply_iterate,
24552455
]
24562456

2457+
# Types compatible with fpext/fptrunc
2458+
const CORE_FLOAT_TYPES = Union{Core.BFloat16, Float16, Float32, Float64}
2459+
24572460
function isdefined_effects(𝕃::AbstractLattice, argtypes::Vector{Any})
24582461
# consistent if the first arg is immutable
24592462
na = length(argtypes)
@@ -2867,6 +2870,17 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
28672870
if !(isprimitivetype(ty) && isprimitivetype(xty))
28682871
return ErrorException
28692872
end
2873+
2874+
# fpext and fptrunc have further restrictions on the allowed types.
2875+
if f === Intrinsics.fpext &&
2876+
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) > Core.sizeof(xty))
2877+
return ErrorException
2878+
end
2879+
if f === Intrinsics.fptrunc &&
2880+
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) < Core.sizeof(xty))
2881+
return ErrorException
2882+
end
2883+
28702884
return Union{}
28712885
end
28722886

Compiler/test/effects.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,3 +1384,14 @@ end |> Compiler.is_nothrow
13841384
@test Base.infer_effects() do
13851385
@ccall unsafecall()::Cvoid
13861386
end == Compiler.EFFECTS_UNKNOWN
1387+
1388+
# fpext
1389+
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float16])
1390+
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float64}, Float16])
1391+
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float64}, Float32])
1392+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float16}, Float16])
1393+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float16}, Float32])
1394+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float32])
1395+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float64])
1396+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Int32}, Float16])
1397+
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Int16])

src/intrinsics.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -672,16 +672,23 @@ static jl_cgval_t generic_cast(
672672
uint32_t nb = jl_datatype_size(jlto);
673673
Type *to = bitstype_to_llvm((jl_value_t*)jlto, ctx.builder.getContext(), true);
674674
Type *vt = bitstype_to_llvm(v.typ, ctx.builder.getContext(), true);
675-
if (toint)
676-
to = INTT(to, DL);
677-
else
678-
to = FLOATT(to);
679-
if (fromint)
680-
vt = INTT(vt, DL);
681-
else
682-
vt = FLOATT(vt);
675+
676+
// fptrunc fpext depend on the specific floating point format to work
677+
// correctly, and so do not pun their argument types.
678+
if (!(f == fpext || f == fptrunc)) {
679+
if (toint)
680+
to = INTT(to, DL);
681+
else
682+
to = FLOATT(to);
683+
if (fromint)
684+
vt = INTT(vt, DL);
685+
else
686+
vt = FLOATT(vt);
687+
}
688+
683689
if (!to || !vt)
684690
return emit_runtime_call(ctx, f, argv, 2);
691+
685692
Value *from = emit_unbox(ctx, vt, v, v.typ);
686693
if (!CastInst::castIsValid(Op, from, to))
687694
return emit_runtime_call(ctx, f, argv, 2);

src/runtime_intrinsics.c

Lines changed: 91 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,11 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT
161161
uint32_t f;
162162
memcpy(&f, &param, sizeof(float));
163163
if (isnan(param)) {
164-
uint32_t t = 0x8000 ^ (0x8000 & ((uint16_t)(f >> 0x10)));
165-
return t ^ ((uint16_t)(f >> 0xd));
164+
// Match the behaviour of arm64's fcvt or x86's vcvtps2ph by quieting
165+
// all NaNs (avoids creating infinities), preserving the sign, and using
166+
// the upper bits of the payload.
167+
// sign exp quiet payload
168+
return (f>>16 & 0x8000) | 0x7c00 | 0x0200 | (f>>13 & 0x03ff);
166169
}
167170
int i = ((f & ~0x007fffff) >> 23);
168171
uint8_t sh = shifttable[i];
@@ -761,33 +764,25 @@ static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_N
761764
OP(ty, (c_type*)pr, a); \
762765
}
763766

764-
#define un_fintrinsic_half(OP, name) \
765-
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
766-
{ \
767-
uint16_t a = *(uint16_t*)pa; \
768-
float A = half_to_float(a); \
769-
if (osize == 16) { \
770-
float R; \
771-
OP(ty, &R, A); \
772-
*(uint16_t*)pr = float_to_half(R); \
773-
} else { \
774-
OP(ty, (uint16_t*)pr, A); \
775-
} \
776-
}
767+
#define un_fintrinsic_half(OP, name) \
768+
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) \
769+
JL_NOTSAFEPOINT \
770+
{ \
771+
uint16_t a = *(uint16_t *)pa; \
772+
float R, A = half_to_float(a); \
773+
OP(ty, &R, A); \
774+
*(uint16_t *)pr = float_to_half(R); \
775+
}
777776

778-
#define un_fintrinsic_bfloat(OP, name) \
779-
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
780-
{ \
781-
uint16_t a = *(uint16_t*)pa; \
782-
float A = bfloat_to_float(a); \
783-
if (osize == 16) { \
784-
float R; \
785-
OP(ty, &R, A); \
786-
*(uint16_t*)pr = float_to_bfloat(R); \
787-
} else { \
788-
OP(ty, (uint16_t*)pr, A); \
789-
} \
790-
}
777+
#define un_fintrinsic_bfloat(OP, name) \
778+
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) \
779+
JL_NOTSAFEPOINT \
780+
{ \
781+
uint16_t a = *(uint16_t *)pa; \
782+
float R, A = bfloat_to_float(a); \
783+
OP(ty, &R, A); \
784+
*(uint16_t *)pr = float_to_bfloat(R); \
785+
}
791786

792787
// float or integer inputs
793788
// OP::Function macro(inputa, inputb)
@@ -1629,32 +1624,74 @@ cvt_iintrinsic(LLVMUItoFP, uitofp)
16291624
cvt_iintrinsic(LLVMFPtoSI, fptosi)
16301625
cvt_iintrinsic(LLVMFPtoUI, fptoui)
16311626

1632-
#define fptrunc(tr, pr, a) \
1633-
if (!(osize < 8 * sizeof(a))) \
1634-
jl_error("fptrunc: output bitsize must be < input bitsize"); \
1635-
else if (osize == 16) { \
1636-
if ((jl_datatype_t*)tr == jl_float16_type) \
1637-
*(uint16_t*)pr = float_to_half(a); \
1638-
else /*if ((jl_datatype_t*)tr == jl_bfloat16_type)*/ \
1639-
*(uint16_t*)pr = float_to_bfloat(a); \
1640-
} \
1641-
else if (osize == 32) \
1642-
*(float*)pr = a; \
1643-
else if (osize == 64) \
1644-
*(double*)pr = a; \
1645-
else \
1646-
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
1647-
#define fpext(tr, pr, a) \
1648-
if (!(osize >= 8 * sizeof(a))) \
1649-
jl_error("fpext: output bitsize must be >= input bitsize"); \
1650-
if (osize == 32) \
1651-
*(float*)pr = a; \
1652-
else if (osize == 64) \
1653-
*(double*)pr = a; \
1654-
else \
1655-
jl_error("fpext: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
1656-
un_fintrinsic_withtype(fptrunc,fptrunc)
1657-
un_fintrinsic_withtype(fpext,fpext)
1627+
#define fintrinsic_read_float16(p) half_to_float(*(uint16_t *)p)
1628+
#define fintrinsic_read_bfloat16(p) bfloat_to_float(*(uint16_t *)p)
1629+
#define fintrinsic_read_float32(p) *(float *)p
1630+
#define fintrinsic_read_float64(p) *(double *)p
1631+
1632+
#define fintrinsic_write_float16(p, x) *(uint16_t *)p = float_to_half(x)
1633+
#define fintrinsic_write_bfloat16(p, x) *(uint16_t *)p = float_to_bfloat(x)
1634+
#define fintrinsic_write_float32(p, x) *(float *)p = x
1635+
#define fintrinsic_write_float64(p, x) *(double *)p = x
1636+
1637+
/*
1638+
* aty: Type of value argument (input)
1639+
* pa: Pointer to value argument data
1640+
* ty: Type argument (output)
1641+
* pr: Pointer to result data
1642+
*/
1643+
1644+
static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *pr)
1645+
{
1646+
unsigned isize = jl_datatype_size(aty), osize = jl_datatype_size(ty);
1647+
if (!(osize < isize)) {
1648+
jl_error("fptrunc: output bitsize must be < input bitsize");
1649+
return;
1650+
}
1651+
1652+
#define fptrunc_convert(in, out) \
1653+
else if (aty == jl_##in##_type && ty == jl_##out##_type) \
1654+
fintrinsic_write_##out(pr, fintrinsic_read_##in(pa))
1655+
1656+
if (0)
1657+
;
1658+
fptrunc_convert(float32, float16);
1659+
fptrunc_convert(float64, float16);
1660+
fptrunc_convert(float32, bfloat16);
1661+
fptrunc_convert(float64, bfloat16);
1662+
fptrunc_convert(float64, float32);
1663+
else
1664+
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
1665+
#undef fptrunc_convert
1666+
}
1667+
1668+
static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *pr)
1669+
{
1670+
unsigned isize = jl_datatype_size(aty), osize = jl_datatype_size(ty);
1671+
if (!(osize > isize)) {
1672+
jl_error("fpext: output bitsize must be > input bitsize");
1673+
return;
1674+
}
1675+
1676+
#define fpext_convert(in, out) \
1677+
else if (aty == jl_##in##_type && ty == jl_##out##_type) \
1678+
fintrinsic_write_##out(pr, fintrinsic_read_##in(pa))
1679+
1680+
if (0)
1681+
;
1682+
fpext_convert(float16, float32);
1683+
fpext_convert(float16, float64);
1684+
fpext_convert(bfloat16, float32);
1685+
fpext_convert(bfloat16, float64);
1686+
fpext_convert(float32, float64);
1687+
else
1688+
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
1689+
#undef fpext_convert
1690+
}
1691+
1692+
cvt_iintrinsic(fptrunc, fptrunc)
1693+
cvt_iintrinsic(fpext, fpext)
1694+
16581695

16591696
// checked arithmetic
16601697
/**

0 commit comments

Comments
 (0)