You can use jax.jacrev and jax.jacfwd.
jax.jacrev can be thought of as a vmapped vjp. It is essentially the same as your code, but it computes the result for all outputs in parallel.
jax.jacfwd can likewise be thought of as a vmapped jvp, batched over the inputs instead.
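
Here is a rough sketch of that equivalence. The function `f` and the point `x` are made up for illustration and are not from your question; the vmap-over-identity pattern is my own simplified version of the idea, not the actual library implementation.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy function from R^3 to R^2, chosen only for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode: a single VJP yields one row of the Jacobian (one output).
# vmapping the VJP over the rows of an identity matrix yields all rows at once.
y, vjp_fn = jax.vjp(f, x)
(jac_rev_manual,) = jax.vmap(vjp_fn)(jnp.eye(y.size))

# Forward mode: a single JVP yields one column of the Jacobian (one input).
# vmapping over the input basis vectors and stacking along axis 1 yields all columns.
def push_forward(v):
    return jax.jvp(f, (x,), (v,))[1]

jac_fwd_manual = jax.vmap(push_forward, out_axes=1)(jnp.eye(x.size))

# Both should agree with the built-in transforms.
jac_rev = jax.jacrev(f)(x)
jac_fwd = jax.jacfwd(f)(x)
```

As a rule of thumb, jacrev tends to be the better choice when there are fewer outputs than inputs, and jacfwd when there are fewer inputs than outputs.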
I recommend reading #47 (comment)
and the Autodiff Cookbook: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html.

Answer selected by DoTulip