custom_vjp with variable nondiff_argnums #13121
-
Hi, I'm porting some PyTorch code over to JAX, and am having some trouble understanding how to implement similar functionality in JAX. In the PyTorch code, the authors have a custom forward and backward function. In the custom backward function, they optionally return the gradient with respect to a variable, depending on whether that input requires a gradient; otherwise they return None.
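Roughly, the pattern looks like this (a minimal sketch using torch.autograd.Function, with made-up names rather than the actual code):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # Only return a gradient for inputs that require one; otherwise return None.
        grad_x = grad_out * w if ctx.needs_input_grad[0] else None
        grad_w = grad_out * x if ctx.needs_input_grad[1] else None
        return grad_x, grad_w
```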
However, from reading the JAX docs, it seems to me that I would have to choose between approaches that don't map cleanly onto this pattern. Is there a way to implement similar logic in JAX? Or can this be avoided somehow?
-
In JAX, I believe the best way to do this would be to use f.defjvps, and pass a function per derivative rule. For example:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y, z):
    return x * jnp.sin(y) * jnp.cos(z)

# One JVP rule per positional argument: each receives that argument's tangent,
# the primal output, and all the primal inputs.
f.defjvps(lambda x_dot, primal_out, x, y, z: x_dot * jnp.sin(y) * jnp.cos(z),
          lambda y_dot, primal_out, x, y, z: y_dot * x * jnp.cos(y) * jnp.cos(z),
          lambda z_dot, primal_out, x, y, z: -z_dot * x * jnp.sin(y) * jnp.sin(z))

print(jax.grad(f, argnums=0)(1.0, 1.0, 1.0))
# 0.45464867
```

You can see some examples of this in the Custom Derivative Rules docs.
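Which gradients are requested is then decided at the call site via argnums rather than inside the rule itself; continuing the example above (output values approximate):

```python
print(jax.grad(f, argnums=1)(1.0, 1.0, 1.0))       # d/dy = x*cos(y)*cos(z) ~ 0.2919
print(jax.grad(f, argnums=(0, 2))(1.0, 1.0, 1.0))  # (d/dx, d/dz) ~ (0.4546, -0.7081)
```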