
I think the issue here is that you use pure_callback within your custom JVP rule. Since the hessian is a second derivative, JAX has to differentiate the JVP rule itself (it computes the JVP of the JVP rule), and pure_callback has no JVP rule of its own, so it hits this error.
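
To see the diagnosis in isolation, here is a minimal sketch, assuming a toy function `f` (just `sin`) whose JVP rule calls a hypothetical helper `cos_cb` that wraps `jax.pure_callback` and has no JVP rule of its own. First-order `jax.grad` only evaluates the rule on concrete primals, so it works; `jax.hessian` has to differentiate the rule, which reaches the callback and raises.

```python
import jax
import jax.numpy as jnp
import numpy as np

def cos_cb(x):
    # Hypothetical helper: host-side NumPy cos via pure_callback.
    # It has no JVP rule, so JAX cannot differentiate through it.
    return jax.pure_callback(
        lambda v: np.cos(v).astype(v.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The callback is called directly inside the JVP rule.
    return jnp.sin(x), cos_cb(x) * x_dot

x = jnp.float32(0.5)
print(jax.grad(f)(x))  # first order: the rule is only evaluated, so this works
try:
    jax.hessian(f)(x)  # second order: JAX must differentiate f_jvp itself
except Exception as err:
    print(type(err).__name__, err)
```
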

If you need pure_callback inside the JVP rule, I'd wrap it in a helper function that has its own JVP rule and call that helper instead.
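
A minimal sketch of that helper pattern, under stated assumptions: `cos_cb` is a hypothetical helper and `f` is a toy function (`sin`), not the code from the question. The callback is wrapped in its own `jax.custom_jvp` whose rule uses only `jnp` ops, so when `jax.hessian` differentiates `f_jvp`, JAX applies `cos_cb_jvp` rather than trying to differentiate `pure_callback`.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def cos_cb(x):
    # Hypothetical helper: host-side NumPy cos via pure_callback.
    return jax.pure_callback(
        lambda v: np.cos(v).astype(v.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )

@cos_cb.defjvp
def cos_cb_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # d/dx cos(x) = -sin(x), written with differentiable jnp ops.
    return cos_cb(x), -jnp.sin(x) * x_dot

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Calls the helper, so differentiating this rule again uses cos_cb_jvp.
    return jnp.sin(x), cos_cb(x) * x_dot

x = jnp.float32(0.5)
print(jax.hessian(f)(x))  # should be close to -sin(0.5)
```

The helper's own rule has to be expressible with differentiable ops (or another custom rule); if the host function has no closed-form derivative, the derivative would need its own callback wrapped the same way.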

Maybe something like this?

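import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # needed for the float64 result shapes below

@jax.custom_jvp
def t(x, y):
  def fun(params):
      # Runs on the host with concrete NumPy values, so use np and the
      # unpacked arguments rather than the traced x from the enclosing scope.
      x_, y_ = params
      return np.cos(x_), np.sin(x_)
  shapes = (jax.ShapeDtypeStruct((), jnp.float64), jax.ShapeDtypeStruct((), jnp.float64))
  return jax.pure_callback(fun, shapes, (x, y))

@t.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents  # neither output depends on y, so y_dot is unused
  # Calling t (not pure_callback) here means a second derivative reuses this rule.
  cos_x, sin_x = t(x, y)
  return (cos_x, sin_x), (-sin_x * x_dot, cos_x * x_dot)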

Answer selected by antalszava
Category
Q&A