
Hi, thanks for the question! I believe it's actually jax.debug.callback, and the utilities built on top of it (such as jax.debug.print and jax.debug.breakpoint), that converts the JAX array to a NumPy array here. This is by design in JAX v0.4.26 and earlier, but it will change in the next JAX release, once #20809 is part of the release.

In either case, the actual computation does not convert any of its inputs to NumPy arrays; only the values handed to the debug callback are affected.
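A minimal sketch for checking this yourself, assuming a reasonably recent JAX install. The names record_type and seen_types are hypothetical and exist only for illustration. The host-side callback records the type of the value it receives (a NumPy array in JAX v0.4.26 and earlier, expected to become a jax.Array after #20809), while the jitted function's own output remains a jax.Array either way.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical helper: records the type of each value the debug callback receives.
seen_types = []

def record_type(x):
    # Runs on the host. In JAX <= 0.4.26, x arrives as a NumPy array;
    # after #20809 it is expected to arrive as a jax.Array instead.
    seen_types.append(type(x))

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # The callback only observes y; the traced computation itself is unchanged.
    jax.debug.callback(record_type, y)
    return y

out = f(jnp.arange(3.0))
# Debug callbacks may run asynchronously; wait for all pending effects to finish.
jax.effects_barrier()

print("callback received:", seen_types[0])
print("function returned:", type(out))
```

Whatever type the callback sees, the value returned from f stays a jax.Array, which is the point made above: the conversion happens at the debug-callback boundary, not inside the computation.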

Answer selected by ToshiyukiBandai
Category: Q&A