Commit 00515b6

Improve timing presentation
1 parent 04cde87 commit 00515b6

File tree

1 file changed (+37, -31 lines)


docs/source/05-speed.ipynb

Lines changed: 37 additions & 31 deletions
@@ -2,9 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 12,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "# Copyright (c) 2024 Graphcore Ltd. All rights reserved.\n",
     "\n",
@@ -14,6 +23,9 @@
     "import ml_dtypes\n",
     "import gfloat\n",
     "from gfloat.formats import format_info_ocp_e5m2\n",
+    "from timeit import Timer\n",
+    "\n",
+    "jax.config.update(\"jax_enable_x64\", True)\n",
     "\n",
     "%load_ext autoreload\n",
     "%autoreload 2"
@@ -32,31 +44,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n",
-      "/home/awf/.micromamba/envs/gfloat-clean/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
-      "  return lax_numpy.astype(arr, dtype, copy=copy, device=device)\n"
-     ]
-    },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "GFloat scalar :616 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
-      "GFloat vectorized, numpy arrays:4.49 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
-      "GFloat vectorized, JAX JIT :596 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
-      "ML_dtypes :266 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+      "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n",
+      "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n",
+      "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n",
+      "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n"
      ]
     }
    ],
    "source": [
-    "N = 100_000\n",
+    "N = 1_000_000\n",
     "a = np.random.rand(N)\n",
     "\n",
     "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n",
@@ -68,17 +71,20 @@
     "    return np.array([gfloat.round_float(fi, x) for x in a])\n",
     "\n",
     "\n",
-    "print(\"GFloat scalar :\", end=\"\")\n",
-    "%timeit slow_round_ndarray(format_info_ocp_e5m2, a)\n",
-    "\n",
-    "print(\"GFloat vectorized, numpy arrays:\", end=\"\")\n",
-    "%timeit gfloat.round_ndarray(format_info_ocp_e5m2, a)\n",
+    "def time(f, problem_size=1.0):\n",
+    "    units = 1e9  # nsec\n",
+    "    t = Timer(f)\n",
+    "    n = t.autorange()[0] * 10  # About 2 sec per run\n",
+    "    ts = t.repeat(repeat=3, number=n)  # best of 3\n",
+    "    ts = [((t / n) / problem_size) * units for t in ts]  # per run\n",
+    "    return f\"{min(ts):8.2f} nsec ({n} runs at size {problem_size})\"\n",
     "\n",
-    "print(\"GFloat vectorized, JAX JIT :\", end=\"\")\n",
-    "%timeit jax_round_jit(ja)\n",
     "\n",
-    "print(\"ML_dtypes :\", end=\"\")\n",
-    "%timeit a.astype(ml_dtypes.float8_e5m2)"
+    "# fmt: off\n",
+    "print(\"GFloat scalar :\", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))\n",
+    "print(\"GFloat vectorized, numpy arrays:\", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))\n",
+    "print(\"GFloat vectorized, JAX JIT :\", time(lambda: jax_round_jit(ja), N))\n",
+    "print(\"ML_dtypes :\", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))"
    ]
   },
   {
@@ -87,12 +93,12 @@
    "source": [
     "On one CPU platform the timings were:\n",
     "```\n",
-    "GFloat scalar :629 ms ± 22.3 ms \n",
-    "GFloat vectorized, numpy arrays: 4.420 ms ± 153 µs \n",
-    "GFloat vectorized, JAX JIT : 585 µs ± 13.7 µs \n",
-    "ML_dtypes : 253 µs ± 12 µs \n",
+    "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n",
+    "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n",
+    "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n",
+    "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n",
     "```\n",
-    "So the JAX JIT code is 1000x faster than the scalar code, although `ml_dtypes`'s C++ is 2-3x faster still."
+    "So the JAX JIT code is ~2000x faster than the scalar code, and comparable to `ml_dtypes`'s C++ CPU implementation."
    ]
   }
  ],
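For readers skimming the diff, the new timing helper can be lifted out of the notebook and run on its own. The sketch below copies the `time` function from the diff (with the comprehension variables renamed for clarity); the numpy `astype` workload at the end is a stand-in, used here only so the snippet runs without `gfloat`, JAX, or `ml_dtypes` installed.

```python
# Standalone sketch of the per-element timing helper introduced by this commit.
# The helper matches the notebook diff; the astype() workload below is a
# stand-in so the example has no gfloat/jax/ml_dtypes dependency.
from timeit import Timer

import numpy as np


def time(f, problem_size=1.0):
    units = 1e9  # report per-element times in nanoseconds
    t = Timer(f)
    n = t.autorange()[0] * 10  # autorange targets ~0.2 s, so ~2 s per timed run
    raw = t.repeat(repeat=3, number=n)  # three timed runs, keep the best
    per_elem = [((total / n) / problem_size) * units for total in raw]
    return f"{min(per_elem):8.2f} nsec ({n} runs at size {problem_size})"


N = 1_000_000
a = np.random.rand(N)

# Stand-in workload: per-element cost of a dtype conversion.
print("numpy astype float32:", time(lambda: a.astype(np.float32), N))
```

Reporting the minimum of the repeats follows the `timeit` documentation's guidance that the fastest run is the most useful lower bound, and dividing by the problem size is what lets the scalar measurement (taken at N // 100 elements) be compared directly with the vectorized measurements at the full N.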
