I am comparing the speed of JAX and NumPy on BFGS minimization of the Rosenbrock function:
import jax.numpy as np
from jax import jit
from jax.scipy.optimize import minimize
import numpy as onp
from scipy.optimize import minimize as ominimize  # SciPy's minimize, used for the NumPy baseline
import jax
jax.config.update("jax_enable_x64", True)
import time
_x0 = onp.random.randn(1000)
@jit
def rosen(x):
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)
x0 = np.asarray(_x0)
t1 = time.time()
result = minimize(rosen, x0, method='BFGS')
# print(result.x)
print(rosen(result.x))
t2 = time.time()
print(t2-t1)
def rosen(x):
    return onp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)
x0 = _x0
t1 = time.time()
result = ominimize(rosen, x0, method='BFGS')
# print(result.x)
print(rosen(result.x))
t2 = time.time()
print(t2-t1)
# Results (seconds):
# numpy: 27.972447872161865
# jax: 31.88966989517212
# It seems that JAX does not give much of a speedup here.
Answered by jakevdp, Jun 3, 2021
Thanks for the question! You might find some useful information in previous discussions of this topic: #6568 #5993 #5876
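Separately from the linked discussions, one general point applies when timing jitted code like the benchmark above: the first call to a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, so a timer should wait on the result. The following is a minimal sketch of that measurement on the `rosen` function alone. It does not reproduce or explain the `minimize` timings in the question, and the variable names (`t_first`, `t_second`) are only illustrative.

```python
import time

import jax
import jax.numpy as jnp
import numpy as onp

jax.config.update("jax_enable_x64", True)


@jax.jit
def rosen(x):
    # Rosenbrock function, same expression as in the question.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)


x0 = jnp.asarray(onp.random.randn(1000))

# First call: traces and compiles the function, then runs it.
t0 = time.perf_counter()
first = rosen(x0).block_until_ready()  # wait for the asynchronous result
t_first = time.perf_counter() - t0

# Second call with the same shape and dtype: reuses the compiled function.
t0 = time.perf_counter()
second = rosen(x0).block_until_ready()
t_second = time.perf_counter() - t0

print(f"first call (includes compilation): {t_first:.6f} s")
print(f"second call (already compiled):    {t_second:.6f} s")
```

Timing a warm-up call separately from later calls makes it easier to see how much of a measured wall-clock time is one-off compilation rather than steady-state execution.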
Answer selected by JiahaoYao