In JAX, I believe the best way to do this is to use `f.defjvps` and pass one derivative rule per positional argument. For example:

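```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y, z):
    return x * jnp.sin(y) * jnp.cos(z)

# One rule per positional argument, in order (x, y, z). Each rule receives that
# argument's tangent, the primal output, and all primal inputs, and returns that
# argument's contribution to the output tangent; JAX sums the contributions.
f.defjvps(lambda x_dot, primal_out, x, y, z: x_dot * jnp.sin(y) * jnp.cos(z),      # (df/dx) * x_dot
          lambda y_dot, primal_out, x, y, z: y_dot * x * jnp.cos(y) * jnp.cos(z),  # (df/dy) * y_dot
          lambda z_dot, primal_out, x, y, z: -z_dot * x * jnp.sin(y) * jnp.sin(z)) # (df/dz) * z_dot

print(jax.grad(f, argnums=0)(1.0, 1.0, 1.0))
# 0.45464867
```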
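
One way to sanity-check hand-written rules like these is to compare them with JAX's own autodiff of an undecorated copy of the function. This check is a sketch and was not part of the original answer. The name `f_ref` and the test point are illustrative assumptions. The final `jax.jvp` call shows that, with `defjvps`, the per-argument contributions are added into a single output tangent.

```python
import jax
import jax.numpy as jnp

# Same custom-JVP function as above, repeated so this snippet runs on its own.
@jax.custom_jvp
def f(x, y, z):
    return x * jnp.sin(y) * jnp.cos(z)

f.defjvps(lambda x_dot, primal_out, x, y, z: x_dot * jnp.sin(y) * jnp.cos(z),
          lambda y_dot, primal_out, x, y, z: y_dot * x * jnp.cos(y) * jnp.cos(z),
          lambda z_dot, primal_out, x, y, z: -z_dot * x * jnp.sin(y) * jnp.sin(z))

# Hypothetical reference version with no custom rule; JAX differentiates it itself.
def f_ref(x, y, z):
    return x * jnp.sin(y) * jnp.cos(z)

args = (0.3, 1.2, -0.7)  # arbitrary test point (an assumption)

# Gradients with respect to all three arguments: custom rules vs. plain autodiff.
custom_grads = jax.grad(f, argnums=(0, 1, 2))(*args)
ref_grads = jax.grad(f_ref, argnums=(0, 1, 2))(*args)
for name, c, r in zip("xyz", custom_grads, ref_grads):
    print(f"df/d{name}: custom={float(c):.6f}  autodiff={float(r):.6f}")

# Forward mode with every tangent set to 1.0: the three rule outputs are summed,
# so the result equals the sum of the three partial derivatives.
_, tangent_out = jax.jvp(f, args, (1.0, 1.0, 1.0))
print("jvp with all-ones tangent:", float(tangent_out),
      "sum of grads:", float(sum(custom_grads)))
```

If a rule had a sign or factor error, the corresponding `custom` and `autodiff` columns would disagree at most test points.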

You can see some examples of this in the Custom Derivative Rules docs.
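
Those docs also describe `defjvp`, which registers a single rule that receives all primals and all tangents at once. The sketch below writes the same function in that style. The names `g` and `g_jvp` are hypothetical. One possible reason to prefer this form is that the primal output and the tangent can reuse intermediate values such as `jnp.sin(y)`. With `defjvps`, the function and each per-argument rule are evaluated separately, although the compiler may still deduplicate some of that work under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def g(x, y, z):
    return x * jnp.sin(y) * jnp.cos(z)

@g.defjvp
def g_jvp(primals, tangents):
    # A single rule sees every primal input and every tangent.
    x, y, z = primals
    x_dot, y_dot, z_dot = tangents
    # Shared intermediates, computed once for both the primal and the tangent.
    sin_y, cos_y = jnp.sin(y), jnp.cos(y)
    sin_z, cos_z = jnp.sin(z), jnp.cos(z)
    primal_out = x * sin_y * cos_z
    # Output tangent: each partial derivative times its input tangent, summed.
    tangent_out = (x_dot * sin_y * cos_z
                   + y_dot * x * cos_y * cos_z
                   - z_dot * x * sin_y * sin_z)
    return primal_out, tangent_out

print(jax.grad(g, argnums=0)(1.0, 1.0, 1.0))
```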

Answer selected by sashaDoubov
Category: Q&A