Is there a good way to save the actual values of a JAX array? #11091
Hi all,

I am new to JAX and I am trying to save the actual values of my JAX arrays during the computation. Currently, I can view my values by calling the id_print function, but I cannot save them as an array. I saw discussion #3766, but it still does not provide a good solution to save the outputs. Could you give me some suggestions on this? Thank you!

I am using the following code:

    def jax_save(array):
        def save_to_file(arg):
            jax.numpy.save('test.npy',arg)
        id_tap(save_to_file,array)

    jax_save(jax_array)

Best,
Replies: 1 comment 1 reply
Depending on how you execute the code, there are a couple of ways to do this. If you aren't using JAX transforms like jit, you can just call the function directly:

    import jax.numpy as jnp
    import numpy as np
    from jax import jit

    def save_no_jit(x):
        # Outside of jit, x is a concrete array, so NumPy can convert and save it.
        np.save('test.npy', x)

    x = jnp.arange(10)
    # Notebook shell escape: remove any stale file from a previous run.
    !rm test.npy
    save_no_jit(x)
    print(jnp.load('test.npy'))
    # [0 1 2 3 4 5 6 7 8 9]

However, if you try this within a JIT-compiled function, it will error, because under jit the argument is an abstract tracer rather than an array with concrete values:

    jit(save_no_jit)(x)
    # TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object

For this case, the best option is probably host_callback.call, which works regardless of whether JIT is used:

    from jax.experimental import host_callback

    def save_with_jit(x):
        # The lambda runs on the host with the concrete value of x.
        host_callback.call(lambda x: np.save('test.npy', x), x)

    save_with_jit(x + 1)
    print(jnp.load('test.npy'))
    # [ 1 2 3 4 5 6 7 8 9 10]

    jit(save_with_jit)(x + 2)
    print(jnp.load('test.npy'))
    # [ 2 3 4 5 6 7 8 9 10 11]
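For readers on a recent JAX release: jax.experimental.host_callback has since been deprecated and removed, so the snippet above may fail at import time. A minimal sketch of the same idea using jax.experimental.io_callback follows. The helper names (save_array, make_saver), the temporary output path, and the use of ordered=True are assumptions made for illustration; they are not part of the original answer.

```python
import functools
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def save_array(path, x):
    """Host-side function: receives a concrete array, not a tracer."""
    np.save(path, np.asarray(x))


def make_saver(path):
    """Return a jit-compatible function that writes its input to `path`."""
    def save_with_io_callback(x):
        # result_shape_dtypes=None because the callback returns nothing.
        # ordered=True asks JAX to keep repeated saves in program order.
        io_callback(functools.partial(save_array, path), None, x, ordered=True)
        return x
    return save_with_io_callback


# Hypothetical output location, chosen only for this example.
out_path = os.path.join(tempfile.mkdtemp(), "test.npy")
saver = jax.jit(make_saver(out_path))

result = saver(jnp.arange(10) + 2)
jax.effects_barrier()  # wait until pending host callbacks have finished
saved = np.load(out_path)
print(saved)
```

The file path is bound with functools.partial rather than passed as a callback argument, so only the array itself crosses from the compiled computation to the host. If the values are only needed after the jitted function returns, returning them from the function and calling np.save outside of jit is usually simpler than any callback.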