Reusing computation results across multiple gradients in JAX #16464
-
Hello,
Is there a way to compute the gradients of both values (an objective and a constraint that share the same expensive intermediate computation) together, while reusing the result of that expensive operation? I understand that JAX avoids side effects and mutable state by design, which makes caching results across multiple function calls tricky. Is there a workaround or recommended way to handle this kind of scenario? Or should I simply rely on JIT for common subexpression elimination, as outlined in Discussion #16356?
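
For concreteness, here is a minimal sketch of the kind of setup I mean (the function names are illustrative placeholders, not my actual code):

```python
import jax
import jax.numpy as jnp

def expensive_op(x):
    # Stand-in for the genuinely costly computation
    return jnp.square(x)

def objective(x):
    return jnp.sum(expensive_op(x))

def constraint(x):
    return jnp.sum(expensive_op(x) * 2)

x = jnp.arange(5.0)

# Naively, each grad call re-runs expensive_op in its own forward pass
objective_grad = jax.grad(objective)(x)
constraint_grad = jax.grad(constraint)(x)
```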
-
If you want to put stuff together manually in this case you could use `jax.vjp`, but you'd need to refactor `objective_fn` and `constraint_fn` to take `y` as an argument. Here's what it might look like:

```python
import jax
from jax import grad
import jax.numpy as jnp

# Dummy expensive operation
def expensive_function(x):
    return jnp.square(x)

# Objective, expressed in terms of the intermediate result y
def objective_fn(y):
    return jnp.sum(y)

# Constraint, also expressed in terms of y
def constraint_fn(y):
    return jnp.sum(y * 2)

# Inputs
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)

# Run the expensive operation once; vjp_fn can be reused for both gradients
y, vjp_fn = jax.vjp(expensive_function, x)

# Gradients with respect to the intermediate result y
objective_ygrad = grad(objective_fn)(y)
constraint_ygrad = grad(constraint_fn)(y)

# Pull both gradients back through the expensive operation via the same vjp_fn
objective_grad = vjp_fn(objective_ygrad)[0]
constraint_grad = vjp_fn(constraint_ygrad)[0]
```

I'm not 100% sure this is correct since I'm not a JAX expert, but I believe it's basically doing what you want. The Autodiff Cookbook is a great resource which explains the various autodiff options in JAX.
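
As a quick sanity check (not part of the original reply), the `vjp`-based gradients can be compared against differentiating the composed functions directly; the two should agree:

```python
# Differentiate the composed functions directly (recomputes expensive_function)
direct_objective_grad = jax.grad(lambda x: objective_fn(expensive_function(x)))(x)
direct_constraint_grad = jax.grad(lambda x: constraint_fn(expensive_function(x)))(x)

assert jnp.allclose(objective_grad, direct_objective_grad)
assert jnp.allclose(constraint_grad, direct_constraint_grad)
```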