lax.scan converts JAX arrays into Numpy arrays #21011
Hi all, I have a question regarding the behavior of lax.scan. In the following example code, the JAX array
Hi - thanks for the question! I believe it's actually `debug.callback` and the utilities built on top of it (i.e. `debug.print`, `debug.breakpoint`) that convert the JAX array to a NumPy array here, and this is by design in JAX v0.4.26 and earlier. Note, though, that this will change in the next JAX release, once #20809 is part of the release. In either case, the actual computation is not converting any inputs to NumPy arrays.
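A minimal sketch to check this yourself, assuming a JAX install where `jax.debug.callback` and `jax.effects_barrier` are available. The helper name `record_type` and the cumulative-sum body are hypothetical, chosen only for illustration. The callback records the type of the value it receives on the host (per the answer, `numpy.ndarray` in v0.4.26 and earlier; this is expected to change once #20809 ships), while the values produced by `lax.scan` itself remain JAX arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# Types seen by the host-side callback, one entry per scan iteration.
seen_types = []

def record_type(x):
    # Runs on the host; x is whatever jax.debug.callback hands over
    # (a NumPy array in JAX <= 0.4.26, per the discussion).
    seen_types.append(type(x))

def step(carry, x):
    # The debug callback only observes x; it does not change the computation.
    jax.debug.callback(record_type, x)
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(4.0)
init = jnp.zeros((), dtype=xs.dtype)  # match the carry dtype exactly
final, cumsum = lax.scan(step, init, xs)
jax.effects_barrier()  # wait until all pending callbacks have run

print("scan outputs:", type(final), type(cumsum))  # JAX arrays
print("callback saw:", seen_types)  # version-dependent
```

Whatever the callback reports, the scan outputs `final` and `cumsum` are `jax.Array` instances, which is the point of the last sentence in the answer.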