Issues with saving and pickling bfloat16 array #8494
Hi, Saving a
As you can see, the result of There is also a problem when using
Is this a bug? Thank you. jaxlib version is
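For the pickling case, here is a workaround that does not come from this thread. It is a minimal sketch that assumes you control both the writing and the reading side. bfloat16 and `uint16` are both 2 bytes wide, so you can serialize the raw bits as `uint16`, which NumPy and `pickle` handle natively, and reinterpret them as bfloat16 after loading. The helper names `to_bits` and `from_bits` are hypothetical.

```python
import io
import pickle

import numpy as np
import jax.numpy as jnp


def to_bits(x):
    """Hypothetical helper: reinterpret a bfloat16 array as uint16 (same 2-byte width)."""
    return np.asarray(x).view(np.uint16)


def from_bits(bits):
    """Hypothetical helper: reinterpret uint16 bits as bfloat16 and return a JAX array."""
    return jnp.asarray(np.asarray(bits).view(jnp.bfloat16))


a = jnp.array([1, 2, 3], dtype=jnp.bfloat16)

# Round trip through np.save / np.load using an in-memory buffer instead of a file.
buf = io.BytesIO()
np.save(buf, to_bits(a))
buf.seek(0)
b = from_bits(np.load(buf))

# Round trip through pickle.
c = from_bits(pickle.loads(pickle.dumps(to_bits(a))))

print(b.dtype, c.dtype)
```

Because the stored dtype is a plain `uint16`, nothing on the loading side has to guess that the bytes were bfloat16. The trade-off is that the reader must know to call `from_bits`.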
Thanks for the report. The issue is that bfloat16 is not a native NumPy type. If you want to recover the original array, you can do so by viewing the output as type `jnp.bfloat16`:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3], dtype=jnp.bfloat16)
jnp.save('test.npy', a)
# bfloat16 is not a native NumPy dtype, so reinterpret the loaded bytes as bfloat16.
b = jnp.load('test.npy').view(jnp.bfloat16)
print(b)
# array([1, 2, 3], dtype=bfloat16)
```

We might think about doing this automatically within the `jnp.load` wrapper (currently `jnp.load = np.load` with no extra logic).

Edit: made this change in #8499
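The automatic handling proposed above could look roughly like the following. This is a sketch under stated assumptions, not the code from #8499. It assumes that a bfloat16 array written by `np.save` loads back with a 2-byte void dtype (`'V2'`), and the names `restore_bfloat16` and `load_with_bfloat16` are hypothetical.

```python
import io

import numpy as np
import jax.numpy as jnp


def restore_bfloat16(out):
    """Hypothetical helper: if `out` is a 2-byte void array, view its bytes as bfloat16."""
    # Assumption: a 'V2' dtype after loading means the array was saved as bfloat16.
    # Non-array results (e.g. the NpzFile returned for .npz files) pass through unchanged.
    if isinstance(out, np.ndarray) and out.dtype == np.dtype('V2'):
        out = out.view(jnp.bfloat16)
    return out


def load_with_bfloat16(file, **kwargs):
    """Hypothetical wrapper: np.load followed by the bfloat16 fix-up."""
    return restore_bfloat16(np.load(file, **kwargs))


a = jnp.array([1, 2, 3], dtype=jnp.bfloat16)
buf = io.BytesIO()
np.save(buf, np.asarray(a))  # NumPy has no native bfloat16, so the file records a void dtype
buf.seek(0)
b = load_with_bfloat16(buf)
print(b.dtype)
```

One consequence of this design is that any genuine 2-byte void data would also be reinterpreted as bfloat16. That is usually acceptable because such data is rare, but it is a heuristic rather than a guarantee.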