Just started using
We don't have a plan to add them, but we could if needed. It's not so difficult to implement them in terms of `mx.vjp`, `mx.jvp`, and `mx.vmap`:

```python
import mlx.core as mx

def jacrev(f):
    def jacfn(x):
        # Needed for the size of the output
        y = f(x)
        def vjpfn(cotan):
            # cotan = e_i gives row i of the Jacobian
            return mx.vjp(f, [x], [cotan])[1][0]
        return mx.vmap(vjpfn, in_axes=0)(mx.eye(len(y)))
    return jacfn

def jacfwd(f):
    def jacfn(x):
        def jvpfn(tan):
            # tan = e_j gives column j of the Jacobian
            return mx.jvp(f, [x], [tan])[1][0]
        # Stack the columns along axis 1 so the result is J, not its transpose
        return mx.vmap(jvpfn, in_axes=0, out_axes=1)(mx.eye(len(x)))
    return jacfn

def hessian(f):
    def hessfn(x):
        def hvp(tan):
            # Forward-over-reverse: differentiate grad(f) along tan
            return mx.jvp(mx.grad(f), [x], [tan])[1][0]
        # The Hessian is symmetric, so stacking along axis 0 is fine
        return mx.vmap(hvp, in_axes=0)(mx.eye(len(x)))
    return hessfn

print(jacrev(mx.sin)(mx.array([1.0, 2.0, 3.0])))
print(jacfwd(mx.sin)(mx.array([1.0, 2.0, 3.0])))

def fun(x):
    return mx.sin(x).sum()

print(hessian(fun)(mx.array([1.0, 2.0, 3.0])))
```
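Because `mx.sin` acts elementwise, its Jacobian is diagonal and therefore symmetric, so the example calls on `mx.sin` can't tell a Jacobian apart from its transpose. The stdlib-only sketch below (not MLX) uses a hypothetical toy function with hand-written `jvp`/`vjp` helpers standing in for `mx.jvp`/`mx.vjp`, just to show how the two modes assemble the Jacobian: reverse mode yields one row per output basis vector, forward mode yields one column per input basis vector.

```python
import math


# Hypothetical toy function f: R^2 -> R^2, f(x) = [x0 * x1, sin(x0)].
# Its Jacobian is [[x1, x0], [cos(x0), 0]], which is not symmetric.
def f(x):
    return [x[0] * x[1], math.sin(x[0])]


# Hand-written forward-mode product J(x) @ t (stand-in for mx.jvp).
def jvp(x, t):
    return [x[1] * t[0] + x[0] * t[1], math.cos(x[0]) * t[0]]


# Hand-written reverse-mode product c @ J(x) (stand-in for mx.vjp).
def vjp(x, c):
    return [x[1] * c[0] + math.cos(x[0]) * c[1], x[0] * c[0]]


def basis(n):
    # Rows of the n x n identity matrix, i.e. e_0, ..., e_{n-1}
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def jacrev_like(x):
    m = len(f(x))  # output size, mirroring the y = f(x) call in jacrev
    # vjp(x, e_i) is row i of J, so the results stack directly as rows.
    return [vjp(x, e) for e in basis(m)]


def jacfwd_like(x):
    # jvp(x, e_j) is column j of J, so the results must be transposed.
    cols = [jvp(x, e) for e in basis(len(x))]
    return [list(row) for row in zip(*cols)]


x = [1.0, 2.0]
naive = [jvp(x, e) for e in basis(len(x))]  # columns stacked as rows: J^T
print(jacfwd_like(x) == jacrev_like(x))  # True
print(naive == jacrev_like(x))  # False
```

In MLX terms, placing the forward-mode results as columns corresponds to passing `out_axes=1` to `mx.vmap` (or transposing the stacked result).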
In JAX, there are `jacfwd`, `jacrev`, and `hessian` functions for transforming the objective function into functions that compute first-order or second-order derivatives. I'm curious to know if MLX has plans to incorporate these three functions in the future.
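For reference, a rough summary of what the three transforms compute, written for $f:\mathbb{R}^n \to \mathbb{R}^m$ with Jacobian entries $[J_f(x)]_{ij} = \partial f_i / \partial x_j$ and standard basis vectors $e_k$ (for the Hessian, $f:\mathbb{R}^n \to \mathbb{R}$). This is one way to see how each reduces to JVPs or VJPs:

```math
\begin{aligned}
\operatorname{jacfwd}(f)(x) &= \begin{bmatrix} J_f(x)\,e_1 & \cdots & J_f(x)\,e_n \end{bmatrix} && \text{(one JVP per input dimension, as columns)} \\
\operatorname{jacrev}(f)(x) &= \begin{bmatrix} e_1^\top J_f(x) \\ \vdots \\ e_m^\top J_f(x) \end{bmatrix} && \text{(one VJP per output dimension, as rows)} \\
\operatorname{hessian}(f)(x) &= J_{\nabla f}(x), \qquad [H_f(x)]_{ij} = \frac{\partial^2 f}{\partial x_i\,\partial x_j}
\end{aligned}
```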