
No, there's no mechanism to do this aside from defining a function in terms of the pytree elements that you're interested in differentiating with respect to. Options for this have been proposed (e.g. #3875 and #10614), but nothing has been implemented yet. Your best bet is probably to compute the gradient with respect to the flattened arguments; it's a bit messy, but the workaround might look something like this:

```python
from jax import tree_util, grad
import jax.numpy as jnp
from typing import NamedTuple

class Data(NamedTuple):
  x: jnp.ndarray
  i: int

data = Data(jnp.arange(5.0), 1)

# We want to compute the gradient of f with respect to x... how?
def f(data):
  return data.x[data.i]

# Wrap it with a fl…
```
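A minimal sketch of both options described above, assuming that only the floating-point leaves of the pytree should be differentiated (the integer index `i` cannot be). Option 1 closes over the rest of the pytree and differentiates a function of `x` alone; option 2 flattens the pytree with `tree_util.tree_flatten` and passes the positions of the float leaves to `grad`'s `argnums`. The names `f_of_x` and `grad_wrt_float_leaves` are hypothetical, chosen for illustration; they are not part of JAX.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import grad, tree_util


class Data(NamedTuple):
  x: jnp.ndarray
  i: int  # integer leaf: not differentiable


data = Data(jnp.arange(5.0), 1)


def f(data):
  return data.x[data.i]


# Option 1: define a function of only the leaf you want to differentiate,
# closing over the rest of the pytree (NamedTuple._replace swaps in x).
def f_of_x(x):
  return f(data._replace(x=x))


grad_x_closure = grad(f_of_x)(data.x)


# Option 2: flatten the pytree and differentiate with respect to the
# floating-point leaves by position via grad's argnums.
def grad_wrt_float_leaves(fun, tree):
  """Hypothetical helper: gradients of fun w.r.t. each inexact-dtype leaf.

  Returns a dict mapping flattened leaf index -> gradient for that leaf.
  """
  leaves, treedef = tree_util.tree_flatten(tree)
  # Keep only float/complex leaves; integer leaves such as data.i are skipped.
  argnums = tuple(
      k for k, leaf in enumerate(leaves)
      if jnp.issubdtype(jnp.result_type(leaf), jnp.inexact)
  )

  def fun_flat(*flat_leaves):
    # Rebuild the original pytree structure before calling fun.
    return fun(tree_util.tree_unflatten(treedef, flat_leaves))

  grads = grad(fun_flat, argnums=argnums)(*leaves)
  return dict(zip(argnums, grads))


grads = grad_wrt_float_leaves(f, data)
# For this Data, leaf 0 is data.x, so grads[0] matches grad_x_closure.
```

Option 1 is simpler when you know exactly which field you want; option 2 works for arbitrary pytrees, at the cost of tracking gradients by flat leaf index rather than by field name.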

Answer selected by conorheins
Category: Q&A