

Try adding the line import jax.numpy as jnp right before you run the breakpoint. If you're on an accelerator, I'd also suggest using numpy rather than jax.numpy, since the values you inspect at the breakpoint are NumPy arrays, not JAX arrays.
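A minimal, non-interactive sketch of the same idea, assuming a recent JAX release. It swaps jax.debug.breakpoint() for jax.debug.callback so it runs without a prompt, and inside the host-side callback it converts the value with np.asarray and inspects it with plain numpy instead of jax.numpy. The names f, inspect_on_host and captured are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side store for what we "inspect" (illustration only).
captured = []

def inspect_on_host(x):
    # Runs on the host: convert to a NumPy array and use plain numpy,
    # not jax.numpy, to look at the value.
    x_np = np.asarray(x)
    captured.append((x_np.shape, float(np.mean(x_np))))

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # Non-interactive stand-in for jax.debug.breakpoint(): pass y to the host.
    jax.debug.callback(inspect_on_host, y)
    return y.sum()

result = f(jnp.arange(4.0))
# Debug callbacks can run asynchronously; wait for them before reading captured.
jax.effects_barrier()
print(captured)
```

In the interactive version you would put jax.debug.breakpoint() where the callback call is and inspect y at the prompt; the point is the same: treat what you see there as NumPy data.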


@HHalva

Answer selected by HHalva
Category: Q&A
2 participants