
Commit b2bedf7

Merge pull request #29 from graphcore-research/ndarray
Add vectorized versions of round/encode/decode.
2 parents ac301f5 + c672a97 commit b2bedf7

13 files changed (+625, -64 lines)


docs/source/04-benchmark.ipynb

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Graphcore Ltd. All rights reserved.\n",
    "\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import ml_dtypes\n",
    "import gfloat\n",
    "from gfloat.formats import format_info_ocp_e5m2\n",
    "from timeit import Timer\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Timing tests\n",
    "\n",
    "The `gfloat` library is designed for readability over performance, and the reference code for computations is the (slow) scalar code e.g. `round_float`.\n",
    "\n",
    "There are vectorized implementations (e.g. `round_ndarray`), and when combined with JAX, these can go reasonably fast.\n",
    "\n",
    "Let's see how long it takes to encode some values to FP8..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GFloat scalar : 6062.08 nsec (25 runs at size 10000)\n",
      "GFloat vectorized, numpy arrays: 53.39 nsec (25 runs at size 1000000)\n",
      "GFloat vectorized, JAX JIT : 3.48 nsec (500 runs at size 1000000)\n",
      "ML_dtypes : 3.27 nsec (500 runs at size 1000000)\n"
     ]
    }
   ],
   "source": [
    "N = 1_000_000\n",
    "a = np.random.rand(N)\n",
    "\n",
    "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n",
    "ja = jnp.array(a)\n",
    "jax_round_jit(ja)  # Cache compilation\n",
    "\n",
    "\n",
    "def slow_round_ndarray(fi, a):\n",
    "    return np.array([gfloat.round_float(fi, x) for x in a])\n",
    "\n",
    "\n",
    "# About how many seconds to run for (autorange will take at least .2 sec)\n",
    "ACCURACY = 1.0\n",
    "\n",
    "\n",
    "def time(f, problem_size=1.0):\n",
    "    units = 1e9  # nsec\n",
    "    t = Timer(f)\n",
    "    f()  # pre-run\n",
    "    n = int(t.autorange()[0] * ACCURACY / 0.2)\n",
    "    ts = t.repeat(repeat=3, number=n)  # best of 3\n",
    "    ts = [((t / n) / problem_size) * units for t in ts]  # per run\n",
    "    return f\"{min(ts):8.2f} nsec ({n} runs at size {problem_size})\"\n",
    "\n",
    "\n",
    "# fmt: off\n",
    "print(\"GFloat scalar :\", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))\n",
    "print(\"GFloat vectorized, numpy arrays:\", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))\n",
    "print(\"GFloat vectorized, JAX JIT :\", time(lambda: jax_round_jit(ja), N))\n",
    "print(\"ML_dtypes :\", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On one CPU platform the timings were:\n",
    "```\n",
    "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n",
    "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n",
    "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n",
    "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n",
    "```\n",
    "So the JAX JIT code is ~1000x faster than the scalar code, and comparable to `ml_dtypes`'s C++ CPU implementation."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
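Outside the notebook, the benchmark distils to a short recipe. Below is a minimal sketch using only the calls that appear in the notebook above (`gfloat.round_ndarray`, `format_info_ocp_e5m2`, and `jax.jit` with `np=jnp`); the array contents are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
import gfloat
from gfloat.formats import format_info_ocp_e5m2

jax.config.update("jax_enable_x64", True)  # as in the notebook

a = np.random.rand(1000)

# Pure NumPy: round each element to the nearest OCP E5M2 representable value.
rounded_np = gfloat.round_ndarray(format_info_ocp_e5m2, a)

# The same call staged through JAX's JIT by passing np=jnp.
round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))
rounded_jax = round_jit(jnp.array(a))
```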

docs/source/api.rst

Lines changed: 8 additions & 1 deletion
@@ -8,9 +8,16 @@ API
 Scalar Functions
 ----------------
 
-.. autofunction:: decode_float
 .. autofunction:: round_float
 .. autofunction:: encode_float
+.. autofunction:: decode_float
+
+Array Functions
+---------------
+
+.. autofunction:: round_ndarray
+.. autofunction:: encode_ndarray
+.. autofunction:: decode_ndarray
 
 Block format functions
 ----------------------
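Each scalar function listed here now has an array counterpart taking the same leading `FormatInfo` argument. A minimal sketch of the two call styles (input values are illustrative):

```python
import numpy as np
import gfloat
from gfloat.formats import format_info_ocp_e5m2

# Scalar reference implementation: one value at a time.
y = gfloat.round_float(format_info_ocp_e5m2, 0.1234)

# Vectorized counterpart added in this PR: a whole array at once.
ys = gfloat.round_ndarray(format_info_ocp_e5m2, np.array([0.1234, 1.5, 3.0]))
```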

docs/source/conf.py

Lines changed: 3 additions & 0 deletions
@@ -52,3 +52,6 @@
 
 # -- Options for EPUB output
 epub_show_urls = "footnote"
+
+# -- Options for myst_nb
+nb_execution_mode = "off"

docs/source/index.rst

Lines changed: 17 additions & 6 deletions
@@ -17,21 +17,32 @@ of:
 * Precision (p)
 * Maximum exponent (emax)
 
-with additional fields defining the encoding of infinities, Not-a-number (NaN) values,
-and negative zero, among others (see :class:`gfloat.FormatInfo`.)
+with additional fields defining the presence/encoding of:
+
+* Infinities
+* Not-a-number (NaN) values
+* Negative zero
+* Subnormal numbers
+* Signed/unsigned
+* Two's complement encoding (of the significand)
 
 This allows an implementation of generic floating point encode/decode logic,
 handling various current and proposed floating point types:
 
 - `IEEE 754 <https://en.wikipedia.org/wiki/IEEE_754>`_: Binary16, Binary32
-- `OCP Float8 <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf>`_: E5M2, E4M3, and MX formats
+- `Brain floating point <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_: BFloat16
+- `OCP Float8 <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf>`_: E5M2, E4M3
 - `IEEE WG P3109 <https://github.com/awf/P3109-Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20report.pdf>`_: P{p} for p in 1..7
+- Types from the `OCP MX <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_ spec: E8M0, INT8, and FP4, FP6 types
+
 
-The library favours readability and extensibility over speed - for fast
-implementations of these datatypes see, for example,
+GFloat, being a pure Python library, favours readability and extensibility over speed
+(although the `*_ndarray` functions are reasonably fast for large arrays).
+For fast implementations of these datatypes see, for example,
 `ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_,
 `bitstring <https://github.com/scott-griffiths/bitstring>`_,
-`MX PyTorch Emulation Library <https://github.com/microsoft/microxcaling>`_.
+`MX PyTorch Emulation Library <https://github.com/microsoft/microxcaling>`_,
+`APyTypes <https://apytypes.github.io/apytypes>`_.
 
 To get started with the library, we recommend perusing the notebooks,
 otherwise you may wish to jump straight into the API.
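The new bullet list corresponds to fields on :class:`gfloat.FormatInfo`. A minimal sketch inspecting them on one format; the attribute names are taken from `decode_ndarray.py` in this commit, and the grouping comments are only descriptive:

```python
from gfloat.formats import format_info_ocp_e5m2 as fi

print(fi.k, fi.precision)                   # total bit width and precision p
print(fi.is_signed, fi.is_twos_complement)  # sign bit / two's complement significand
print(fi.has_infs, fi.num_nans, fi.has_nz)  # infinities, NaN count, negative zero
print(fi.has_subnormals, fi.expBias)        # subnormal support, exponent bias
```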

docs/source/notebooks.rst

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ Some notebooks to illustrate uses of the library
 01-decode.ipynb
 02-value-stats.ipynb
 03-value-tables.ipynb
+04-benchmark.ipynb

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -19,3 +19,5 @@ myst_nb
 # Requirements for notebooks
 airium
 pandas
+jaxlib
+jax

src/gfloat/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@
 from .decode import decode_float
 from .printing import float_pow2str, float_tilde_unless_roundtrip_str
 from .round import encode_float, round_float
+from .round_ndarray import encode_ndarray, round_ndarray
+from .decode_ndarray import decode_ndarray
 from .types import FloatClass, FloatValue, FormatInfo, RoundMode
 
 # Don't automatically import from .formats.

src/gfloat/decode_ndarray.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from types import ModuleType
import numpy as np
from .types import FormatInfo


def decode_ndarray(
    fi: FormatInfo, codes: np.ndarray, np: ModuleType = np
) -> np.ndarray:
    r"""
    Vectorized version of :meth:`decode_float`

    Args:
      fi (FormatInfo): Floating point format descriptor.
      codes (array of int): Integer code points, in the range :math:`0 \le i < 2^{k}`,
        where :math:`k` = ``fi.k``

    Returns:
      Decoded float values

    Raises:
      ValueError:
        If any :paramref:`codes` is outside the range of valid code points in :paramref:`fi`.
    """
    assert np.issubdtype(codes.dtype, np.integer)

    k = fi.k
    p = fi.precision
    t = p - 1  # Trailing significand field width
    num_signbits = 1 if fi.is_signed else 0
    w = k - t - num_signbits  # Exponent field width

    if np.any(codes < 0) or np.any(codes >= 2**k):
        raise ValueError(f"Code point not in range [0, 2**{k})")

    if fi.is_signed:
        signmask = 1 << (k - 1)
        sign = np.where(codes & signmask, -1.0, 1.0)
    else:
        signmask = None
        sign = 1.0

    exp = ((codes >> t) & ((1 << w) - 1)).astype(np.int64)
    significand = codes & ((1 << t) - 1)
    if fi.is_twos_complement:
        significand = np.where(sign < 0, (1 << t) - significand, significand)

    expBias = fi.expBias

    iszero = (exp == 0) & (significand == 0) & fi.has_zero
    issubnormal = (exp == 0) & (significand != 0) & fi.has_subnormals
    isnormal = ~iszero & ~issubnormal
    expval = np.where(~isnormal, 1 - expBias, exp - expBias)
    fsignificand = np.where(~isnormal, significand * 2**-t, 1.0 + significand * 2**-t)

    # Normal/Subnormal/Zero case, other values will be overwritten
    fval = np.where(iszero, 0.0, sign * fsignificand * 2.0**expval)

    if fi.has_infs:
        fval = np.where(codes == fi.code_of_posinf, np.inf, fval)
        fval = np.where(codes == fi.code_of_neginf, -np.inf, fval)

    if fi.num_nans > 0:
        code_is_nan = codes == fi.code_of_nan
        if w > 0:
            # All-bits-special exponent (ABSE)
            abse = exp == 2**w - 1
            min_code_with_nan = 2 ** (p - 1) - fi.num_high_nans
            code_is_nan |= abse & (significand >= min_code_with_nan)

        fval = np.where(code_is_nan, np.nan, fval)

    # Negative zero
    if fi.has_nz:
        fval = np.where(iszero & (sign < 0), -0.0, fval)

    return fval
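A short usage sketch of the new function, decoding every 8-bit code point of the OCP E5M2 format (the same format object used in the benchmark notebook); the printed quantities are purely illustrative:

```python
import numpy as np
from gfloat import decode_ndarray
from gfloat.formats import format_info_ocp_e5m2

# All 256 code points of an 8-bit format; the dtype must be an integer type.
codes = np.arange(256, dtype=np.uint8)
values = decode_ndarray(format_info_ocp_e5m2, codes)

print(values[:3])                # 0.0 followed by the two smallest positive subnormals
print(np.sum(np.isnan(values)))  # how many code points decode to NaN
```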
