Thanks for the report. The issue is that bfloat16 is not a native NumPy type, so the .npy file records it only as raw 2-byte (void) data, and that is what np.load gives back. If you want to recover the original array, you can view the loaded data as jnp.bfloat16:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3], dtype=jnp.bfloat16)
jnp.save('test.npy', a)
# Reinterpret the loaded 2-byte elements as bfloat16 (no copy, no conversion)
b = jnp.load('test.npy').view(jnp.bfloat16)
print(b)        # [1 2 3]
print(b.dtype)  # bfloat16
```
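To make the failure mode concrete, here is a minimal sketch of what np.load returns before the view. It assumes jax and NumPy are installed and uses an in-memory io.BytesIO buffer instead of test.npy. It also decodes the same bytes by hand, relying on the fact that bfloat16 is the upper 16 bits of an IEEE float32, which shows that the saved bytes are intact and only the type information was lost.

```python
import io

import jax.numpy as jnp
import numpy as np

# A bfloat16 array, moved to the host as a NumPy array.
a = np.asarray(jnp.array([1, 2, 3], dtype=jnp.bfloat16))

# Round-trip through the .npy format in memory rather than on disk.
buf = io.BytesIO()
np.save(buf, a)
buf.seek(0)
raw = np.load(buf)

# NumPy only knows each element is 2 opaque bytes: a void ("V") dtype.
print(raw.dtype.kind, raw.dtype.itemsize)

# view() reinterprets those bytes as bfloat16 without copying.
restored = raw.view(jnp.bfloat16)
print(restored.dtype, restored.astype(np.float32))

# Same bytes decoded by hand: shift each 16-bit pattern into the
# high half of a 32-bit word and read it as float32.
bits = raw.view(np.uint16).astype(np.uint32) << 16
print(bits.view(np.float32))
```

Because view() never touches the underlying bytes, it is lossless here; astype() would instead try to convert values and does not make sense for a void dtype.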

We might think about doing this automatically within the jnp.load wrapper (at the time of writing, jnp.load was simply np.load with no extra logic). Edit: made this change in #8499, so jnp.load now performs this bfloat16 view itself.
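For readers on older JAX versions, the idea of such a wrapper can be sketched as below. This is a hypothetical helper (load_bfloat16_aware is an illustrative name, not a JAX API, and not the code from #8499). It assumes that any 2-byte void array in a single-array .npy file came from a bfloat16 array written with np.save or jnp.save; .npz archives are passed through unchanged.

```python
import io

import jax.numpy as jnp
import numpy as np


def load_bfloat16_aware(file, **kwargs):
    """Hypothetical helper: np.load, then recover bfloat16 from 2-byte void data.

    Assumes V2 data was originally bfloat16. Returns a NumPy array, not a JAX array.
    """
    out = np.load(file, **kwargs)
    # Only plain arrays are handled; NpzFile archives are returned as-is.
    if isinstance(out, np.ndarray) and out.dtype == np.dtype("V2"):
        out = out.view(jnp.bfloat16)
    return out


# Usage: save a bfloat16 array and load it back with the helper.
buf = io.BytesIO()
np.save(buf, np.asarray(jnp.array([1, 2, 3], dtype=jnp.bfloat16)))
buf.seek(0)
x = load_bfloat16_aware(buf)
print(x.dtype)
```

The trade-off of this heuristic is that genuine 2-byte void data, which is rare in practice, would also be reinterpreted as bfloat16.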

Answer selected by ivandustin
Category: Q&A
Labels: none
3 participants: @ivandustin, @grahamgower, @jakevdp