
Edit: there is now a FAQ entry on this topic: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
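As background for that FAQ entry: inside a jitted function, the argument is an abstract tracer that carries only a shape and dtype, not a value, so NumPy cannot convert it to a concrete array. The following is a minimal sketch showing the error JAX raises when you try. The function names `to_numpy` and `conversion_error_name` are hypothetical, chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def to_numpy(x):
    # Inside jit, x is an abstract tracer (shape and dtype only, no value),
    # so NumPy cannot materialize it as a concrete array.
    return np.asarray(x)


def conversion_error_name():
    # Return the name of the exception raised by the failed conversion.
    try:
        to_numpy(jnp.arange(3))
    except jax.errors.TracerArrayConversionError as err:
        return type(err).__name__
    return None


print(conversion_error_name())  # TracerArrayConversionError
```

Outside of `jit`, `np.asarray(jnp.arange(3))` works as expected, because the array then holds a concrete value.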


Thanks for the question! You can print traced values at runtime using jax.debug.print. Note, however, that the syntax of jax.debug.print differs from that of Python's built-in print: the first argument must be a format string. It might look something like this:

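```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("traced:", x[0])              # runs during tracing: prints the abstract tracer
  jax.debug.print("debug: {}", x[0])  # runs at execution time: prints the concrete value

f(jnp.arange(10))
```

Output:

```
traced: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
debug: 0
```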
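Because the first argument is an ordinary format string, named placeholders filled from keyword arguments also work, as with `str.format`. A minimal sketch, where `summarize` is a hypothetical function name used only for illustration:

```python
import jax
import jax.numpy as jnp


@jax.jit
def summarize(x):
    total = x.sum()
    # Named placeholders are filled from keyword arguments, as with str.format.
    jax.debug.print("first={first}, total={total}", first=x[0], total=total)
    return total


result = summarize(jnp.arange(10))  # prints: first=0, total=45
```

Unlike the built-in print, which fires only while the function is being traced, jax.debug.print fires on every call, so calling `summarize` a second time with the same input prints the line again.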

Answer selected by mattiadutto