Here is the code from #936:

```python
import time

import jax.numpy as jnp
import numpy as onp
from scipy.optimize import minimize


def run(xp):
    # xp is the array module under test: NumPy (onp) or jax.numpy (jnp).
    def rosen(x):
        return xp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    x0 = xp.array([1.3, 0.7, 0.8, 1.9, 1.2, 1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = xp.array([[0.7, 1.3]] * 10)  # defined but never passed to minimize
    result = minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})
    return result


t1 = time.time()
run(onp)
t2 = time.time()
print(t2 - t1)

t1 = time.time()
run(jnp)
t2 = time.time()
print(t2 - t1)
```

When I compare the speed of JAX and NumPy, I find that JAX is about 50x slower. Does anyone know how to make the JAX code faster? Here is the result:

Or is it supposed to be this slow?
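To see where the per-call cost goes, one can time a single evaluation of the objective on a NumPy input, first with NumPy and then with `jax.numpy`. This mirrors the situation inside `scipy.optimize.minimize`, which always hands the objective NumPy arrays. This is a minimal sketch: `rosen_np`, `rosen_jnp`, and `time_calls` are hypothetical helper names, and absolute timings depend on the machine and JAX version. The `block_until_ready()` call is there because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp
import numpy as onp


def rosen_np(x):
    # Rosenbrock function evaluated entirely with NumPy.
    return onp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)


def rosen_jnp(x):
    # Same expression as in the question: with a NumPy input, the arithmetic
    # runs in NumPy and only the final reduction goes through jax.numpy,
    # which converts its argument to a JAX array and returns a JAX array.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)


def time_calls(fn, x, n=1000):
    # Average wall-clock seconds per call of fn(x).
    fn(x)  # warm-up call, excluded from the timing
    start = time.perf_counter()
    for _ in range(n):
        out = fn(x)
    # Wait for the last JAX result to finish; NumPy results need no waiting.
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return (time.perf_counter() - start) / n


x = onp.array([1.3, 0.7, 0.8, 1.9, 1.2, 1.3, 0.7, 0.8, 1.9, 1.2])
print("NumPy result is a NumPy scalar:", isinstance(rosen_np(x), onp.generic))
print("jax.numpy result is a JAX array:", isinstance(rosen_jnp(x), jax.Array))
print(f"numpy:     {time_calls(rosen_np, x):.2e} s/call")
print(f"jax.numpy: {time_calls(rosen_jnp, x):.2e} s/call")
```

The JAX version is expected to show a higher per-call cost on small inputs, because every call converts a NumPy array and dispatches a JAX operation. Without a `jac` argument, SLSQP estimates gradients by finite differences and so evaluates the objective many times, which lets that overhead accumulate.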
Any time you call a SciPy or NumPy function on a JAX array, the data is copied from the JAX device (whatever it is) to the CPU as a NumPy array, and the result is returned as a NumPy array. When you call a JAX function on a NumPy array, the data is copied from the CPU to the JAX device as a `DeviceArray`, and the result is returned as a `DeviceArray`. Note that when JAX's default device is the CPU, many of these copies are zero-copy views, but not all: some data buffers will be copied.

This repeated data movement leads to slow execution in your example: `scipy.optimize.minimize` does its computation on NumPy arrays, while your objective function does its computation on JAX arrays, and the data is passed back and …

Instead, you should try JAX-native minimization, such as … (see the related discussion in #5292).
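The specific JAX-native routine recommended above is cut off. As one possible option, assumed here rather than taken from the answer, `jax.scipy.optimize.minimize` keeps the objective, its gradient (via automatic differentiation), and the optimizer loop on the JAX side, so nothing crosses the NumPy/JAX boundary during the solve. Note that this swaps in BFGS for SLSQP: this function supports `method="BFGS"` but not SLSQP, and it does not take `bounds`. The sketch also assumes 64-bit mode is enabled through the `jax_enable_x64` flag so the result is comparable with SciPy's float64 run; `solve` is a hypothetical helper name.

```python
import jax

# Assumption: enable float64 so BFGS behaves like SciPy's double-precision run.
# This must be set before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.optimize import minimize


def rosen(x):
    # Rosenbrock function written entirely in jax.numpy.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)


@jax.jit
def solve(x0):
    # The whole BFGS loop is traced and compiled by XLA once; gradients of
    # rosen come from jax.grad internally instead of finite differences.
    return minimize(rosen, x0, method="BFGS", tol=1e-6)


x0 = jnp.array([1.3, 0.7, 0.8, 1.9, 1.2, 1.3, 0.7, 0.8, 1.9, 1.2])
result = solve(x0)
print("converged:", result.success, "iterations:", result.nit)
print("f(x*):", result.fun)
print("x*:", result.x)
```

Because `solve` is wrapped in `jax.jit`, its first call includes compilation time. When comparing against the SciPy version, time a second call (and call `block_until_ready()` on `result.x`) to measure steady-state speed.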