-
I would like to use JAX to accelerate scipy's optimize, but it gives wrong answers. Here is the jax-scipy notebook of the MWE, with the original problem:

from scipy.optimize import minimize

fun = lambda x: (x[0] - 1)**2 + (x[1] - 2.5)**2
res = minimize(fun, (2, 0), method='SLSQP')
res

The results are:

I am changing this to JAX:

import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def jax_fun(x):
    return (x[0] - 1)**2 + (x[1] - 2.5)**2

res = minimize(lambda x: np.asarray(jax_fun(jnp.asarray(x))), np.array([2, 0]), method='SLSQP', options={'disp': True})
res

The result is:

I would appreciate it if someone could point out the problem, or if anyone has had the same experience before.
-
I think the issue is that JAX's default computation is 32-bit, while scipy's minimize algorithm assumes 64-bit precision. See Double (64 bit) Precision for more details. In brief, add these lines at the top of your notebook and restart the runtime, and you will get the expected results:

from jax.config import config
config.update("jax_enable_x64", True)

If you want to minimize a function in a way that uses JAX's auto-differentiation features and is compatible with JAX regardless of the dtype setting, you might take a look at jax.scipy.optimize.minimize:

from jax import jit
import jax.numpy as jnp
from jax.scipy.optimize import minimize

@jit
def jax_fun(x):
    return (x[0] - 1)**2 + (x[1] - 2.5)**2

minimize(jax_fun, jnp.array([2.0, 0.0]), method="BFGS")
# OptimizeResults(x=DeviceArray([1. , 2.5], dtype=float32), success=DeviceArray(True, dtype=bool), status=DeviceArray(0, dtype=int32), fun=DeviceArray(0., dtype=float32), jac=DeviceArray([0., 0.], dtype=float32), hess_inv=DeviceArray([[0.93103445, 0.1724138 ],
#                    [0.1724138 , 0.56896555]], dtype=float32), nfev=DeviceArray(3, dtype=int32), njev=DeviceArray(3, dtype=int32), nit=DeviceArray(1, dtype=int32))
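If you would rather keep scipy's SLSQP after enabling x64, one rough sketch is to let jax.grad supply the exact gradient so SLSQP does not rely on finite differences of the 32-bit function; the lambda wrappers here just convert between NumPy and JAX arrays:

from jax.config import config
config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp
from jax import jit, grad
from scipy.optimize import minimize

@jit
def jax_fun(x):
    return (x[0] - 1)**2 + (x[1] - 2.5)**2

# grad(jax_fun) returns the exact gradient of the objective
jax_grad = jit(grad(jax_fun))

res = minimize(lambda x: float(jax_fun(jnp.asarray(x))),
               np.array([2.0, 0.0]),
               jac=lambda x: np.asarray(jax_grad(jnp.asarray(x))),
               method='SLSQP')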
-
There is a minus sign in the lambda that wraps the jax_fun call that doesn't seem to belong there; with it, SLSQP minimizes the negated paraboloid rather than the original objective, so the two optimization problems are not the same.
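Assuming the notebook's wrapper did have a leading minus (it isn't visible in the snippet quoted above), dropping it should restore the original problem:

# hypothetical corrected wrapper: no leading minus, so SLSQP minimizes the
# same paraboloid as the pure-NumPy version
res = minimize(lambda x: np.asarray(jax_fun(jnp.asarray(x))),
               np.array([2, 0]), method='SLSQP', options={'disp': True})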