Evaluate jax.grad(f) with respect to particular leaves in a PyTree
#12765
-
This question seems very similar to this one, but I wanted to write it again here both to increase the visibility of the issue and to make the question more concise/clear as well as more general. Consider a function …

Is there a quick/JAX-ian way to compute …? My naïve approach would be to map … Thanks in advance!
Replies: 2 comments
-
No, there's no mechanism to do this aside from defining a function in terms of the pytree elements that you're interested in differentiating against. Options for this have been proposed (e.g. #3875 and #10614) but nothing has yet been implemented. Your best bet is probably to compute the gradient with respect to flattened arguments; it's a bit messy, but the workaround might look something like this:

```python
from jax import tree_util, grad
import jax.numpy as jnp
from typing import NamedTuple

class Data(NamedTuple):
    x: jnp.ndarray
    i: int

data = Data(jnp.arange(5.0), 1)

# We want to compute the gradient of f with respect to x... how?
def f(data):
    return data.x[data.i]

# Wrap it with a flattened version of the arguments...
def f_wrap(x, i, tree):
    data = tree_util.tree_unflatten(tree, (x, i))
    return f(data)

args, tree = tree_util.tree_flatten(data)
result = grad(f_wrap)(*args, tree=tree)
```
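For comparison, the other route mentioned above (defining a function directly in terms of the pytree leaf you want to differentiate against) can be sketched as follows. The lambda wrapper here is my own illustration of that idea, not code from the original answer:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Data(NamedTuple):
    x: jnp.ndarray
    i: int

data = Data(jnp.arange(5.0), 1)

def f(data):
    return data.x[data.i]

# Differentiate with respect to the leaf `x` only, by closing over the
# remaining leaf (`i`) and rebuilding the pytree inside the wrapper.
grad_x = jax.grad(lambda x: f(Data(x, data.i)))(data.x)
print(grad_x)  # one-hot at index 1: [0. 1. 0. 0. 0.]
```

This avoids the flatten/unflatten machinery at the cost of writing a wrapper per leaf of interest, which is exactly the manual workaround the answer describes.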
-
Thanks very much for the quick answer @jakevdp -- much appreciated & good to know!