I think the issue is that JAX computes in 32-bit precision by default, while SciPy's minimize routine assumes 64-bit precision. See Double (64 bit) Precision for more details.

In brief, add these lines at the top of your notebook and restart the runtime, and you will get the expected results:

import jax
jax.config.update("jax_enable_x64", True)
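
As a rough sketch of what this setting changes: once x64 is enabled, JAX arrays default to float64, and a JAX function can be passed to scipy.optimize.minimize together with its gradient. The objective, the starting point, and the wrapper name fun_np are made up for illustration and are not taken from the question.

```python
import jax

# Must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

print(jnp.zeros(2).dtype)  # float64 when x64 is enabled

# Hypothetical objective with its minimum at (1, 2).
def fun(x):
    return jnp.sum((x - jnp.array([1.0, 2.0])) ** 2)

# Compile the value and gradient together so SciPy gets both in one call.
value_and_grad = jax.jit(jax.value_and_grad(fun))

def fun_np(x):
    # Convert JAX outputs to plain Python/NumPy types for SciPy.
    value, grad = value_and_grad(x)
    return float(value), np.asarray(grad, dtype=np.float64)

# jac=True tells SciPy that fun_np returns (value, gradient).
res = minimize(fun_np, x0=np.zeros(2), jac=True, method="BFGS")
print(res.x)
```

Passing jac=True avoids SciPy's finite-difference gradient, which is the part most sensitive to 32-bit rounding.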

If you want to minimize a function in a way that uses JAX's auto-differentiation features and is compatible with JAX regardless of the dtype setting, you might take a look at jax.scipy.optimize.minimize:

from jax import jit
import jax.numpy as jnp
from jax.scipy.optimize import minimize

@jit
def jax_fun(x): 
    return (x[0] - 1)**2 + (x[1] 
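
For reference, a minimal self-contained sketch of calling jax.scipy.optimize.minimize. The objective here is an assumption chosen for illustration (a quadratic with its minimum at (1, 2)), not necessarily the function from the snippet above. As far as I know, the JAX version requires an explicit method argument and "BFGS" is the supported choice.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical objective for illustration; minimum at (1, 2).
def objective(x):
    return (x[0] - 1.0) ** 2 + (x[1] - 2.0) ** 2

x0 = jnp.zeros(2)

# Gradients are obtained through JAX automatic differentiation,
# so no separate jac argument is passed.
result = minimize(objective, x0, method="BFGS")

print(result.x)        # estimated minimizer
print(result.fun)      # objective value at result.x
print(result.success)  # convergence flag reported by the solver
```

This runs under either dtype setting; in the default 32-bit mode the result is only accurate to roughly single precision.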

Answer selected by JiahaoYao