|
2 | 2 | "cells": [
3 | 3 | {
4 | 4 | "cell_type": "code",
5 | | - "execution_count": 1,
| 5 | + "execution_count": 12,
6 | 6 | "metadata": {},
7 | | - "outputs": [],
| 7 | + "outputs": [
| 8 | + {
| 9 | + "name": "stdout",
| 10 | + "output_type": "stream",
| 11 | + "text": [
| 12 | + "The autoreload extension is already loaded. To reload it, use:\n",
| 13 | + " %reload_ext autoreload\n"
| 14 | + ]
| 15 | + }
| 16 | + ],
8 | 17 | "source": [
9 | 18 | "# Copyright (c) 2024 Graphcore Ltd. All rights reserved.\n",
10 | 19 | "\n",

14 | 23 | "import ml_dtypes\n",
15 | 24 | "import gfloat\n",
16 | 25 | "from gfloat.formats import format_info_ocp_e5m2\n",
| 26 | + "from timeit import Timer\n",
| 27 | + "\n",
| 28 | + "jax.config.update(\"jax_enable_x64\", True)\n",
17 | 29 | "\n",
18 | 30 | "%load_ext autoreload\n",
19 | 31 | "%autoreload 2"
|
|
32 | 44 | },
33 | 45 | {
34 | 46 | "cell_type": "code",
35 | | - "execution_count": 2,
| 47 | + "execution_count": 24,
36 | 48 | "metadata": {},
37 | 49 | "outputs": [
38 | | - {
39 | | - "name": "stderr",
40 | | - "output_type": "stream",
41 | | - "text": [
42 | | - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n",
43 | | - "/home/awf/.micromamba/envs/gfloat-clean/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
44 | | - " return lax_numpy.astype(arr, dtype, copy=copy, device=device)\n"
45 | | - ]
46 | | - },
47 | 50 | {
48 | 51 | "name": "stdout",
49 | 52 | "output_type": "stream",
50 | 53 | "text": [
51 | | - "GFloat scalar :616 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
52 | | - "GFloat vectorized, numpy arrays:4.49 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
53 | | - "GFloat vectorized, JAX JIT :596 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
54 | | - "ML_dtypes :266 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
| 54 | + "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n",
| 55 | + "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n",
| 56 | + "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n",
| 57 | + "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n"
55 | 58 | ]
56 | 59 | }
57 | 60 | ],
58 | 61 | "source": [
59 | | - "N = 100_000\n",
| 62 | + "N = 1_000_000\n",
60 | 63 | "a = np.random.rand(N)\n",
61 | 64 | "\n",
62 | 65 | "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n",
|
68 | 71 | " return np.array([gfloat.round_float(fi, x) for x in a])\n",
69 | 72 | "\n",
70 | 73 | "\n",
71 | | - "print(\"GFloat scalar :\", end=\"\")\n",
72 | | - "%timeit slow_round_ndarray(format_info_ocp_e5m2, a)\n",
73 | | - "\n",
74 | | - "print(\"GFloat vectorized, numpy arrays:\", end=\"\")\n",
75 | | - "%timeit gfloat.round_ndarray(format_info_ocp_e5m2, a)\n",
| 74 | + "def time(f, problem_size=1.0):\n",
| 75 | + " units = 1e9 # nsec\n",
| 76 | + " t = Timer(f)\n",
| 77 | + " n = t.autorange()[0] * 10 # About 2 sec per run\n",
| 78 | + " ts = t.repeat(repeat=3, number=n) # best of 3\n",
| 79 | + " ts = [((t / n) / problem_size) * units for t in ts] # per run\n",
| 80 | + " return f\"{min(ts):8.2f} nsec ({n} runs at size {problem_size})\"\n",
76 | 81 | "\n",
77 | | - "print(\"GFloat vectorized, JAX JIT :\", end=\"\")\n",
78 | | - "%timeit jax_round_jit(ja)\n",
79 | 82 | "\n",
80 | | - "print(\"ML_dtypes :\", end=\"\")\n",
81 | | - "%timeit a.astype(ml_dtypes.float8_e5m2)"
| 83 | + "# fmt: off\n",
| 84 | + "print(\"GFloat scalar :\", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))\n",
| 85 | + "print(\"GFloat vectorized, numpy arrays:\", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))\n",
| 86 | + "print(\"GFloat vectorized, JAX JIT :\", time(lambda: jax_round_jit(ja), N))\n",
| 87 | + "print(\"ML_dtypes :\", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))"
82 | 88 | ]
83 | 89 | },
84 | 90 | {
|
87 | 93 | "source": [
88 | 94 | "On one CPU platform the timings were:\n",
89 | 95 | "```\n",
90 | | - "GFloat scalar :629 ms ± 22.3 ms \n",
91 | | - "GFloat vectorized, numpy arrays: 4.420 ms ± 153 µs \n",
92 | | - "GFloat vectorized, JAX JIT : 585 µs ± 13.7 µs \n",
93 | | - "ML_dtypes : 253 µs ± 12 µs \n",
| 96 | + "GFloat scalar : 6996.75 nsec (50 runs at size 10000)\n",
| 97 | + "GFloat vectorized, numpy arrays: 75.04 nsec (50 runs at size 1000000)\n",
| 98 | + "GFloat vectorized, JAX JIT : 3.18 nsec (1000 runs at size 1000000)\n",
| 99 | + "ML_dtypes : 3.13 nsec (1000 runs at size 1000000)\n",
94 | 100 | "```\n",
95 | | - "So the JAX JIT code is 1000x faster than the scalar code, although `ml_dtypes`'s C++ is 2-3x faster still."
| 101 | + "So the JAX JIT code is ~2000x faster than the scalar code, and comparable to `ml_dtypes`'s C++ CPU implementation."
96 | 102 | ]
97 | 103 | }
98 | 104 | ],
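For readers who want to reproduce the new timings outside the notebook, here is a minimal standalone sketch of the per-element timing approach this commit introduces: timeit.Timer.autorange() picks a loop count, Timer.repeat() takes the best of three runs, and the elapsed time is normalised to nanoseconds per array element. The helper name time_per_element and the choice of benchmarking only the ml_dtypes cast are illustrative assumptions, not part of the commit.

# Sketch of the per-element timing helper from the diff above (names are illustrative).
import numpy as np
import ml_dtypes
from timeit import Timer

def time_per_element(f, problem_size=1.0):
    # Report the best-of-3 per-element time of f() in nanoseconds.
    units = 1e9                        # seconds -> nanoseconds
    t = Timer(f)
    n = t.autorange()[0] * 10          # scale up the auto-chosen loop count
    ts = t.repeat(repeat=3, number=n)  # three independent timing runs
    ts = [(elapsed / n) / problem_size * units for elapsed in ts]
    return f"{min(ts):8.2f} nsec ({n} runs at size {problem_size})"

N = 1_000_000
a = np.random.rand(N)
print("ML_dtypes:", time_per_element(lambda: a.astype(ml_dtypes.float8_e5m2), N))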