Hessian with jax.pure_callback for custom JVP/custom VJP
#13282
Hi all,

How can a Hessian be computed when `jax.pure_callback` is used inside a custom JVP rule? Consider the following example:

```python
import jax
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.custom_jvp
def f(x, y):
    def fun(params):
        x, y = params
        return jnp.sin(x) * y
    r = jax.pure_callback(fun, jax.ShapeDtypeStruct((), jnp.float64), (x, y))
    return r

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    def t_out(params):
        x, y = params
        return jnp.cos(x), jnp.sin(x)
    shapes = (jax.ShapeDtypeStruct((), jnp.float64), jax.ShapeDtypeStruct((), jnp.float64))
    # pure_callback is called directly inside the JVP rule
    cos_x, sin_x = jax.pure_callback(t_out, shapes, (x, y))
    tangent_out = cos_x * x_dot * y + sin_x * y_dot
    return primal_out, tangent_out

a = jnp.array(0.3)
b = jnp.array(0.3)
```

Here, computing the Hessian of `f` raises:

```
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
```

I suspect that a recursive approach, with custom JVP registrations at each level, is needed to get higher-order derivatives here. Note that removing the …

Thanks!
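A stripped-down reduction may make the failure easier to see. The sketch below is illustrative only: the name `sin_cb` is hypothetical (not from this thread), and it uses NumPy inside the callbacks. Like `f_jvp` above, its custom JVP rule calls `jax.pure_callback` directly. A first-order `jax.grad` succeeds because JAX only evaluates the rule, whereas `jax.hessian` also has to differentiate the rule and so reaches the bare callback.

```python
import jax
import numpy as np
from jax import numpy as jnp


@jax.custom_jvp
def sin_cb(x):
    # Hypothetical helper: sin evaluated on the host through a callback.
    x = jnp.asarray(x)
    return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@sin_cb.defjvp
def sin_cb_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    x = jnp.asarray(x)
    # pure_callback is used directly inside the rule, mirroring f_jvp.
    cos_x = jax.pure_callback(np.cos, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return sin_cb(x), cos_x * x_dot


x0 = jnp.array(0.3)

# First order: the rule is only evaluated, so this works.
first = jax.grad(sin_cb)(x0)

# Second order: JAX must differentiate the rule itself, which hits the callback.
try:
    jax.hessian(sin_cb)(x0)
    second_order_failed = False
except ValueError as err:
    second_order_failed = True
    print(err)
```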
I think the issue here is that you use `pure_callback` within your custom JVP rule. Since the Hessian is a second-order derivative, JAX computes the JVP of the JVP rule itself and then hits this error.

If you need to use `pure_callback` within the JVP rule, I'd wrap it in a helper function that has its own JVP rule and call that helper from the rule. Maybe something like this?

```python
import jax
import numpy as np
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.custom_jvp
def t(x, y):
    def fun(params):
        # Use the callback's own arguments rather than closing over traced x/y
        x, y = params
        return np.cos(x), np.sin(x)
    shapes = (jax.ShapeDtypeStruct((), jnp.float64), jax.ShapeDtypeStruct((), jnp.float64))
    return jax.pure_callback(fun, shapes, (x, y))

@t.defjvp
def t_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    # Calling t (not pure_callback) here lets JAX differentiate this rule again
    cos_x, sin_x = t(x, y)
    # d(cos x) = -sin x dx and d(sin x) = cos x dx; neither output depends on y
    return (cos_x, sin_x), (-sin_x * x_dot, cos_x * x_dot)

@jax.custom_jvp
def f(x, y):
    def fun(params):
        x, y = params
        return np.sin(x) * y
    r = jax.pure_callback(fun, jax.ShapeDtypeStruct((), jnp.float64), (x, y))
    return r

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    cos_x, sin_x = t(x, y)
    tangent_out = cos_x * x_dot * y + sin_x * y_dot
    return primal_out, tangent_out

a = jnp.array(0.3)
b = jnp.array(0.3)
print(jax.hessian(f, argnums=(0, 1))(a, b))
```
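The same idea can be pushed one step further. If every callback-backed helper has a JVP rule written only in terms of other callback-backed helpers that also have JVP rules, the set is closed under differentiation, so derivatives of any order work and `f` itself needs no custom rule. The sketch below assumes hypothetical helpers `_host`, `cb_sin` and `cb_cos` (not from this thread) that evaluate NumPy on the host and whose rules refer to each other; it computes the Hessian of `sin(x) * y` and a third derivative of `sin`.

```python
import jax
import numpy as np
from jax import numpy as jnp


def _host(fn, x):
    # Hypothetical helper: run a NumPy ufunc on the host via pure_callback.
    # The output has the same shape and dtype as x.
    x = jnp.asarray(x)
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@jax.custom_jvp
def cb_sin(x):
    return _host(np.sin, x)


@jax.custom_jvp
def cb_cos(x):
    return _host(np.cos, x)


@cb_sin.defjvp
def _cb_sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # d sin(x) = cos(x) dx, with cos also going through a helper that has a rule
    return cb_sin(x), cb_cos(x) * x_dot


@cb_cos.defjvp
def _cb_cos_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # d cos(x) = -sin(x) dx
    return cb_cos(x), -cb_sin(x) * x_dot


def f(x, y):
    # No custom rule here: JAX differentiates through cb_sin's rule.
    return cb_sin(x) * y


a = jnp.array(0.3)
b = jnp.array(0.3)
hess = jax.hessian(f, argnums=(0, 1))(a, b)  # ((d2f/dx2, d2f/dxdy), (d2f/dydx, d2f/dy2))
third = jax.grad(jax.grad(jax.grad(cb_sin)))(a)  # third derivative of sin
```

With this structure the trade-off is that every helper invocation, including those made inside the JVP rules, is a separate host callback.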