From traced array to numpy array #12969
I'm using JAX for a Federated Learning project. I want to check a few values inside different traced arrays and compare them with values inside NumPy arrays. So I have the following questions:
The best that I found is
Thanks.
Replies: 1 comment 3 replies
Edit: there is now a FAQ entry on this topic: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
Thanks for the question! You can print traced values at runtime using jax.debug.print. Note, however, that the syntax of debug.print is different from that of normal print statements: the first argument must be a format string. It might look something like this:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Plain print runs at trace time, so it shows the tracer, not the value.
    print("traced:", x[0])
    # jax.debug.print runs at runtime; the first argument is a format string.
    jax.debug.print("debug: {}", x[0])

f(jnp.arange(10))
```
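For the comparison against NumPy arrays that the question asks about, printing may not be enough. The following is only a minimal sketch, not necessarily the approach the FAQ recommends: it assumes a JAX version that provides jax.debug.callback and jax.effects_barrier, and the names reference, captured, and record are hypothetical, chosen for illustration. It passes the runtime value of a traced array to a host-side function, where it can be handled as an ordinary NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical reference values held in a plain NumPy array.
reference = np.arange(10, dtype=np.float32) * 2.0

# Filled in by the host callback below (illustrative name).
captured = []

def record(value):
    # Runs on the host with the concrete runtime value, so NumPy works here.
    captured.append(np.asarray(value))

@jax.jit
def f(x):
    y = x * 2.0
    # Inside jit, y is a tracer; hand its runtime value to the host function.
    jax.debug.callback(record, y)
    return y

out = f(jnp.arange(10, dtype=jnp.float32))

# Wait for pending side effects so the callback has run before reading captured.
jax.effects_barrier()

# Outside of jit the result is a concrete array and converts directly.
out_np = np.asarray(out)
print(np.allclose(out_np, reference))
print(np.allclose(captured[0], reference))
```

If only the final outputs need checking, returning them from the jitted function and calling np.asarray on the result is probably the simpler route; the callback is mainly useful for intermediate values that never leave the function.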