|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": 2, |
| 5 | + "execution_count": 1, |
6 | 6 | "metadata": {},
|
7 | 7 | "outputs": [],
|
8 | 8 | "source": [
|
|
16 | 16 | "from gfloat.formats import format_info_ocp_e5m2\n",
|
17 | 17 | "from timeit import Timer\n",
|
18 | 18 | "\n",
|
19 |
| - "jax.config.update(\"jax_enable_x64\", True)\n", |
20 |
| - "\n", |
21 |
| - "%load_ext autoreload\n", |
22 |
| - "%autoreload 2" |
| 19 | + "jax.config.update(\"jax_enable_x64\", True)" |
23 | 20 | ]
|
24 | 21 | },
|
25 | 22 | {
|
|
37 | 34 | },
|
38 | 35 | {
|
39 | 36 | "cell_type": "code",
|
40 |
| - "execution_count": 5, |
| 37 | + "execution_count": 2, |
41 | 38 | "metadata": {},
|
42 | 39 | "outputs": [
|
| 40 | + { |
| 41 | + "name": "stderr", |
| 42 | + "output_type": "stream", |
| 43 | + "text": [ |
| 44 | + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" |
| 45 | + ] |
| 46 | + }, |
43 | 47 | {
|
44 | 48 | "name": "stdout",
|
45 | 49 | "output_type": "stream",
|
46 | 50 | "text": [
|
47 |
| - "GFloat scalar : 7518.31 nsec (5 runs at size 10000)\n", |
48 |
| - "GFloat vectorized, numpy arrays: 57.95 nsec (5 runs at size 1000000)\n", |
49 |
| - "GFloat vectorized, JAX JIT : 4.03 nsec (100 runs at size 1000000)\n", |
50 |
| - "ML_dtypes : 3.34 nsec (100 runs at size 1000000)\n" |
| 51 | + "GFloat scalar : 6062.08 nsec (25 runs at size 10000)\n", |
| 52 | + "GFloat vectorized, numpy arrays: 53.39 nsec (25 runs at size 1000000)\n", |
| 53 | + "GFloat vectorized, JAX JIT : 3.48 nsec (500 runs at size 1000000)\n", |
| 54 | + "ML_dtypes : 3.27 nsec (500 runs at size 1000000)\n" |
51 | 55 | ]
|
52 | 56 | }
|
53 | 57 | ],
|
|
65 | 69 | "\n",
|
66 | 70 | "\n",
|
67 | 71 | "# About how many seconds to run for (autorange will take at least .2 sec)\n",
|
68 |
| - "ACCURACY = 0.2\n", |
| 72 | + "ACCURACY = 1.0\n", |
69 | 73 | "\n",
|
70 | 74 | "\n",
|
71 | 75 | "def time(f, problem_size=1.0):\n",
|
|
0 commit comments