Skip to content

Commit 9828d73

Browse files
committed
Initial implementation
1 parent febff62 commit 9828d73

File tree

3 files changed

+102
-4
lines changed

3 files changed

+102
-4
lines changed

src/gfloat/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .printing import float_pow2str, float_tilde_unless_roundtrip_str
1212
from .round import encode_float, round_float
1313
from .round_ndarray import encode_ndarray, round_ndarray
14+
from .decode_ndarray import decode_ndarray
1415
from .types import FloatClass, FloatValue, FormatInfo, RoundMode
1516

1617
# Don't automatically import from .formats.

src/gfloat/decode_ndarray.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
import numpy as np
4+
5+
from .types import FloatClass, FloatValue, FormatInfo
6+
7+
8+
def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np=np) -> np.ndarray:
9+
r"""
10+
Vectorized version of :function:`decode_float`
11+
12+
Args:
13+
fi (FormatInfo): Floating point format descriptor.
14+
i (array of int): Integer code points, in the range :math:`0 \le i < 2^{k}`,
15+
where :math:`k` = ``fi.k``
16+
17+
Returns:
18+
Decoded float values
19+
20+
Raises:
21+
ValueError:
22+
If any :paramref:`i` is outside the range of valid code points in :paramref:`fi`.
23+
"""
24+
assert np.issubdtype(codes.dtype, np.integer)
25+
26+
k = fi.k
27+
p = fi.precision
28+
t = p - 1 # Trailing significand field width
29+
num_signbits = 1 if fi.is_signed else 0
30+
w = k - t - num_signbits # Exponent field width
31+
32+
if np.any(codes < 0) or np.any(codes >= 2**k):
33+
raise ValueError(f"Code point not in range [0, 2**{k})")
34+
35+
if fi.is_signed:
36+
signmask = 1 << (k - 1)
37+
sign = np.where(codes & signmask, -1.0, 1.0)
38+
else:
39+
signmask = None
40+
sign = 1.0
41+
42+
exp = (codes >> t) & ((1 << w) - 1)
43+
significand = codes & ((1 << t) - 1)
44+
if fi.is_twos_complement:
45+
significand = np.where(sign < 0, (1 << t) - significand, significand)
46+
47+
expBias = fi.expBias
48+
49+
iszero = (exp == 0) & (significand == 0) if fi.has_zero else False
50+
issubnormal = (exp == 0) & (significand != 0) if fi.has_subnormals else False
51+
isnormal = ~iszero & ~issubnormal
52+
expval = np.where(~isnormal, 1 - expBias, exp - expBias)
53+
fsignificand = np.where(~isnormal, significand * 2**-t, 1.0 + significand * 2**-t)
54+
55+
# Normal/Subnormal/Zero case, other values will be overwritten
56+
fval = np.where(iszero, 0.0, sign * fsignificand * 2.0**expval)
57+
58+
# All-bits-special exponent (ABSE)
59+
if w > 0:
60+
abse = exp == 2**w - 1
61+
min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans
62+
fval = np.where(abse & (significand >= min_i_with_nan), np.nan, fval)
63+
if fi.has_infs:
64+
fval = np.where(
65+
abse & (significand == min_i_with_nan - 1), np.inf * sign, fval
66+
)
67+
68+
# Negative zero
69+
if fi.has_nz:
70+
fval = np.where(iszero & (sign < 0), -0.0, fval)
71+
else:
72+
# Negative zero slot is nan
73+
fval = np.where(codes == fi.code_of_negzero, np.nan, fval)
74+
75+
return fval

test/test_decode.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,42 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2-
2+
from typing import Callable
33
import ml_dtypes
44
import numpy as np
55
import pytest
66

7-
from gfloat import FloatClass, decode_float
7+
from gfloat import FloatClass, decode_float, decode_ndarray
88
from gfloat.formats import *
99

1010

1111
def _isnegzero(x: float) -> bool:
1212
return (x == 0) and (np.signbit(x) == 1)
1313

1414

15-
def test_spot_check_ocp_e5m2() -> None:
15+
methods = ["scalar", "array"]
16+
17+
18+
def get_method(method: str, fi: FormatInfo) -> Callable:
19+
if method == "scalar":
20+
21+
def dec(code: int) -> float:
22+
return decode_float(fi, code).fval
23+
24+
if method == "array":
25+
26+
def dec(code: int) -> float:
27+
asnp = np.tile(np.array(code, dtype=np.int64), (2, 3))
28+
vals = decode_ndarray(fi, asnp)
29+
val = vals.item(0)
30+
np.testing.assert_equal(val, vals)
31+
return val
32+
33+
return dec
34+
35+
36+
@pytest.mark.parametrize("method", methods)
37+
def test_spot_check_ocp_e5m2(method) -> None:
1638
fi = format_info_ocp_e5m2
17-
dec = lambda code: decode_float(fi, code).fval
39+
dec = get_method(method, fi)
1840
fclass = lambda code: decode_float(fi, code).fclass
1941
assert dec(0x01) == 2.0**-16
2042
assert dec(0x40) == 2.0

0 commit comments

Comments
 (0)