We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0c437c4 commit 74e42f9Copy full SHA for 74e42f9
test/test_jax.py
@@ -0,0 +1,32 @@
1
+# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2
+
3
+import numpy as np
4
5
+import jax
6
+import jax.numpy as jnp
7
8
+import ml_dtypes
9
10
+import gfloat
11
+from gfloat.formats import *
12
13
+jax.config.update("jax_enable_x64", True)
14
15
16
+def test_jax() -> None:
17
+ """
18
+ Test that JAX JIT produces correct output
19
20
+ a = np.random.randn(1024)
21
22
+ a8 = a.astype(ml_dtypes.float8_e5m2).astype(jnp.float64)
23
24
+ fi = format_info_ocp_e5m2
25
+ j8 = gfloat.round_ndarray(fi, jnp.array(a), np=jnp)
26
27
+ np.testing.assert_equal(a8, j8)
28
29
+ jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x, np=jnp))
30
+ j8i = jax_round_array(a)
31
32
+ np.testing.assert_equal(a8, j8i)
0 commit comments