No, there's no syntax to compute gradients with respect to a single element in a pytree, though it's been discussed (see #3875, #10614 and related discussions).
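`jax.grad`'s `argnums` parameter selects whole positional arguments, not individual leaves inside a pytree. One common manual workaround is to close over the fields you are not differentiating, so that the single field of interest becomes the only differentiated input. Below is a minimal sketch of that idea. The `Params` NamedTuple, the `loss` function and `grad_wrt_a` are hypothetical names chosen for illustration and are not part of the original question:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # Hypothetical pytree with two fields; only `a` is differentiated below.
    a: jnp.ndarray
    b: jnp.ndarray


def loss(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical scalar loss, used only for illustration.
    return jnp.sum(params.a * x) + jnp.sum(params.b ** 2)


def grad_wrt_a(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    # Close over every field except `a`, so jax.grad only sees `a` as its input.
    def loss_of_a(a):
        return loss(params._replace(a=a), x)

    return jax.grad(loss_of_a)(params.a)


params = Params(a=jnp.ones(5), b=jnp.arange(3.0))
x = jnp.arange(5.0)
print(grad_wrt_a(params, x))  # d/da sum(a * x) is x itself
```

This works, but it means rewriting the function for each field you care about, which is part of why the whole-pytree approach is usually simpler.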

The easiest way to proceed is typically to take the gradient with respect to the whole pytree, then access the particular element you're interested in and continue your computation. If this is wrapped in jit, the compiler will automatically eliminate any computations whose results are never used. You can confirm this by using ahead-of-time compilation to print the compiled HLO.

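For example:

```python
import jax

# `g` and `x` come from the original question and are not shown here;
# g(x) returns a pytree with an attribute `a`.
@jax.jit
def func(x):
  return g(x).a

print(func.lower(x).compile().as_text())
```

```text
HloModule jit_func, entry_computation_layout={(f32[])->f32[5]{0}}, allow_spmd_sharding…
```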
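Because `g` and `x` are not shown above, here is a self-contained sketch of the same pattern under stated assumptions. The `Params` NamedTuple, the `loss` function and `grad_a` are hypothetical. It takes the gradient with respect to the whole pytree, keeps only field `a`, and uses the ahead-of-time API (`.lower(...).compile().as_text()`) to print the optimized HLO. Comparing that output with the un-optimized `grad_a.lower(params, x).as_text()` is one way to see which unused computations (here, the work for the gradient of `b`) the compiler dropped:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # Hypothetical pytree; only the gradient for field `a` is kept.
    a: jnp.ndarray
    b: jnp.ndarray


def loss(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical scalar loss, used only for illustration.
    return jnp.sum(params.a * x) + jnp.sum(jnp.sin(params.b))


@jax.jit
def grad_a(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    # Differentiate with respect to the whole pytree, then keep one field.
    # The gradient for `b` is never returned, so XLA is free to remove it.
    return jax.grad(loss)(params, x).a


params = Params(a=jnp.ones(5), b=jnp.arange(3.0))
x = jnp.arange(5.0)
print(grad_a(params, x))

# Ahead-of-time lowering and compilation, then print the optimized HLO.
compiled = grad_a.lower(params, x).compile()
print(compiled.as_text())
```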

Replies: 1 comment 2 replies

@alonfnt

@jakevdp

Answer selected by alonfnt
Category
Q&A
2 participants