There's no way to make vmap(grad(fv)) work for this function, because grad requires a scalar-output function, and your fv function returns a vector output that does not correspond to any vector input. If you're interested in the gradient of a single output, you could do something like this:

jax.grad(lambda x: fv(x)[0])(x)

If you're interested in computing the gradients for all outputs at once, you could do something like this:

jax.jacrev(fv)(x)
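
To make the relationship between the two approaches concrete, here is a minimal sketch. The definition of fv below is a hypothetical stand-in (the original function is not shown here), as is the input x; the only assumption is that fv maps a vector to a vector.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for fv: maps a length-3 vector to a length-2 vector.
def fv(x):
    return jnp.stack([jnp.sum(x ** 2), jnp.sum(jnp.sin(x))])

# Hypothetical input, chosen only for illustration.
x = jnp.array([1.0, 2.0, 3.0])

# Gradient of one output: wrap fv so that grad sees a scalar-output function.
g0 = jax.grad(lambda x: fv(x)[0])(x)   # shape (3,)

# Gradients of all outputs at once: the Jacobian, one row per output.
jac = jax.jacrev(fv)(x)                # shape (2, 3)

# Row i of the Jacobian matches the gradient of output i.
assert jnp.allclose(jac[0], g0)
```

With this stand-in, jacrev returns an array whose leading axis indexes the outputs of fv, so selecting a row gives the same result as taking grad of the corresponding single output.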

Answer selected by ipcamit