Skip to content

Commit 74e42f9

Browse files
committed
Add JAX test
1 parent 0c437c4 commit 74e42f9

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

test/test_jax.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
import numpy as np
4+
5+
import jax
6+
import jax.numpy as jnp
7+
8+
import ml_dtypes
9+
10+
import gfloat
11+
from gfloat.formats import *
12+
13+
jax.config.update("jax_enable_x64", True)
14+
15+
16+
def test_jax() -> None:
17+
"""
18+
Test that JAX JIT produces correct output
19+
"""
20+
a = np.random.randn(1024)
21+
22+
a8 = a.astype(ml_dtypes.float8_e5m2).astype(jnp.float64)
23+
24+
fi = format_info_ocp_e5m2
25+
j8 = gfloat.round_ndarray(fi, jnp.array(a), np=jnp)
26+
27+
np.testing.assert_equal(a8, j8)
28+
29+
jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x, np=jnp))
30+
j8i = jax_round_array(a)
31+
32+
np.testing.assert_equal(a8, j8i)

0 commit comments

Comments
 (0)