Any time you call a SciPy or NumPy function on a JAX array, the data is copied from the JAX device (whatever it is) to the CPU as a NumPy array, and the result is a NumPy array. When you call a JAX function on a NumPy array, the data is copied from the CPU to the JAX device as a DeviceArray, and the result is returned as a DeviceArray. Note that when JAX's default device is the CPU, many of these copies are zero-copy views, but not all: some data buffers will still be copied. This repeated data movement slows down your example: scipy.optimize.minimize does its computation on NumPy arrays, your minimization function does its computation on JAX arrays, and the data is passed back and …
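The data movement described above can be observed directly. This is a minimal sketch, assuming a recent JAX release where the array type is `jax.Array` (it replaced `DeviceArray`). The toy objective, `TARGET`, and the `scipy_objective` wrapper are hypothetical illustrations, not code from the original question. The second half shows one common way to limit the back-and-forth when you do keep `scipy.optimize.minimize`: jit-compile the objective and its gradient together and convert the results to NumPy once per call.

```python
import numpy as np
import scipy.optimize
import jax
import jax.numpy as jnp

# --- Crossing the NumPy/JAX boundary ---------------------------------------
x_jax = jnp.arange(4.0)     # a JAX array on JAX's default device
np_result = np.sin(x_jax)   # NumPy ufunc: data moves to host memory, result is np.ndarray

x_np = np.arange(4.0)       # a NumPy array in host memory
jnp_result = jnp.sin(x_np)  # JAX function: data moves to the JAX device, result is jax.Array

# --- One round trip per scipy.optimize.minimize evaluation ------------------
TARGET = jnp.array([1.0, 2.0, 3.0])  # hypothetical minimizer of the toy objective


def objective(params):
    # Hypothetical toy objective: a quadratic bowl centred on TARGET.
    return jnp.sum((params - TARGET) ** 2)


# Compile value and gradient together so each evaluation is a single JAX call.
value_and_grad = jax.jit(jax.value_and_grad(objective))


def scipy_objective(params_np):
    # SciPy passes a float64 NumPy array; JAX moves it to the device
    # (and casts it to float32 unless 64-bit mode is enabled).
    value, grad = value_and_grad(params_np)
    # Convert back to the host types SciPy expects: a Python float and a float64 array.
    return float(value), np.asarray(grad, dtype=np.float64)


# jac=True: scipy_objective returns (value, gradient) in one call.
res = scipy.optimize.minimize(scipy_objective, np.zeros(3), jac=True, method="BFGS")
print(res.x)
```

Passing `jac=True` tells SciPy that the objective returns both the value and the gradient, so each optimizer evaluation triggers one compiled JAX call instead of separate value and gradient computations. Data still crosses the NumPy/JAX boundary on every call; this only reduces how many crossings there are.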

@JiahaoYao
@jakevdp
@JiahaoYao
Answer selected by JiahaoYao
Category
Q&A