
Depending on how you execute the code, there are a couple of ways to do this. If you aren't using JAX transforms like jit, you can call the function directly:

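import os

import jax.numpy as jnp
import numpy as np
from jax import jit

def save_no_jit(x):
  # Outside of jit, x is a concrete array, so NumPy can convert and save it.
  np.save('test.npy', x)

x = jnp.arange(10)

# Remove any file left over from a previous run.
if os.path.exists('test.npy'):
  os.remove('test.npy')
save_no_jit(x)

print(jnp.load('test.npy'))
# [0 1 2 3 4 5 6 7 8 9]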

However, the same function raises an error when it is JIT-compiled:

jit(save_no_jit)(x)
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object 
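The reason: under jit, JAX traces the function with abstract tracer values that carry only a shape and dtype, so there are no concrete numbers for NumPy to convert. A minimal sketch that reproduces this and catches the error as jax.errors.TracerArrayConversionError (the helper name to_numpy is hypothetical):

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy(x):
  # Hypothetical helper: np.asarray needs concrete values,
  # which a jit tracer does not have.
  return np.asarray(x)


x = jnp.arange(10)
print(to_numpy(x))  # eager call: x is a concrete array, so this works

try:
  jax.jit(to_numpy)(x)  # during tracing, x is a Tracer, not an array
  failed = False
except jax.errors.TracerArrayConversionError:
  failed = True

print(failed)
```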

For this case, the best option is probably host_callback.call, which works whether or not JIT is used:

from jax.e…
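Note that host_callback has been deprecated in more recent JAX releases in favor of jax.debug.callback, jax.pure_callback and jax.experimental.io_callback. As a minimal sketch assuming a recent JAX version, the following swaps in jax.debug.callback to save the array from inside a jitted function; save_and_double, _save_to_disk and the file name test_callback.npy are hypothetical names chosen for illustration:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def _save_to_disk(arr, path):
  # Hypothetical helper. It runs on the host with a concrete array,
  # not a tracer, so converting it with np.asarray is allowed here.
  np.save(path, np.asarray(arr))


@functools.partial(jax.jit, static_argnames='path')
def save_and_double(x, path):
  # Stage a host-side call; unlike a direct np.save, it survives jit
  # and runs when the compiled function executes. The path string is
  # bound with functools.partial because callback args must be arrays.
  jax.debug.callback(functools.partial(_save_to_disk, path=path), x)
  return 2 * x


x = jnp.arange(10)
y = save_and_double(x, path='test_callback.npy')
jax.effects_barrier()  # wait until the host callback has finished writing
print(np.load('test_callback.npy'))
```

jax.effects_barrier() is there because jit dispatch is asynchronous, so the file may not be written yet when the call returns. For side effects such as writing to disk, JAX's external-callbacks documentation recommends jax.experimental.io_callback; jax.debug.callback is primarily intended for debugging.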

Answer selected by tianaidong