jax.jacrev performance: custom_vjp vs custom_jvp #12859
Hi everyone, assuming we want to get the Jacobian of a differentiable function
I've come across the note on Decomposing reverse-mode automatic differentiation, and I'd be keen to learn more about how (if at all) the decomposition affects performance. Does the performance difference between 1. and 2. (assuming it's significant) depend on:
Let me know if it would be helpful to share more about the …
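As a concrete reference point for the comparison in the title, here is a minimal sketch of the same derivative written once with jax.custom_jvp and once with jax.custom_vjp, then differentiated with jax.jacrev. The function f(x) = x * sin(x) and all names (f_jvp, f_vjp, and so on) are hypothetical, chosen only for illustration. Because both versions encode the same derivative, the two Jacobians should agree.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise function: f(x) = x * sin(x).

# Spelling 1: a custom JVP rule (forward-mode), which jacrev linearizes and transposes.
@jax.custom_jvp
def f_jvp(x):
    return jnp.sin(x) * x

@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    y = jnp.sin(x) * x
    # The output tangent must be linear in t for reverse mode to work.
    dy = (jnp.cos(x) * x + jnp.sin(x)) * t
    return y, dy

# Spelling 2: a custom VJP rule with explicit fwd/bwd functions.
@jax.custom_vjp
def f_vjp(x):
    return jnp.sin(x) * x

def f_vjp_fwd(x):
    # Returns (primal output, residuals); here the residual is just x.
    return jnp.sin(x) * x, x

def f_vjp_bwd(x, g):
    # Returns one cotangent per primal input, as a tuple.
    return ((jnp.cos(x) * x + jnp.sin(x)) * g,)

f_vjp.defvjp(f_vjp_fwd, f_vjp_bwd)

x = jnp.arange(1.0, 4.0)
J_jvp = jax.jacrev(f_jvp)(x)  # 3x3 Jacobian via the custom JVP rule
J_vjp = jax.jacrev(f_vjp)(x)  # 3x3 Jacobian via the custom VJP rule
```

To see whether the two spellings lead to different computations, one option is to compare jax.make_jaxpr(jax.jacrev(f_jvp))(x) with jax.make_jaxpr(jax.jacrev(f_vjp))(x).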
Thanks for the question!

Whether you use custom_jvp or custom_vjp need not affect performance at all. In fact, in most cases you can "spell" the exact same computation either way.

I say "in most cases" as a hedge because there are cases where they might differ. Here is a non-exhaustive list that comes to mind:

- With custom_vjp you can more directly control what the "residuals" are, i.e. what values are saved from the forward pass for use on the backward pass. With custom_jvp you're leaving the residuals up to JAX's partial evaluation machinery. Those can be controlled too (e.g. with jax.remat / jax.checkpoint), but it's not "built in" the way it is with custom_vjp. Choice of residuals can…
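A minimal sketch of the residual point, using a hypothetical sin-like function (all names are illustrative). The two custom_vjp variants differ only in what their fwd function returns as residuals: the input x (so cos(x) is recomputed on the backward pass) or cos(x) itself. The custom_jvp version has no fwd/bwd split to edit, so its residuals come from JAX's partial evaluation; wrapping it in jax.checkpoint is one way to steer that, since by default it saves only the wrapped function's inputs and recomputes the rest. All four should produce the same gradient; they differ in the memory versus recompute trade-off, not in the values.

```python
import jax
import jax.numpy as jnp

# custom_vjp variant A: save only the input x as the residual.
@jax.custom_vjp
def g_save_input(x):
    return jnp.sin(x)

def g_save_input_fwd(x):
    return jnp.sin(x), x  # residual: x

def g_save_input_bwd(x, ct):
    return (jnp.cos(x) * ct,)  # cos(x) is recomputed here

g_save_input.defvjp(g_save_input_fwd, g_save_input_bwd)

# custom_vjp variant B: save cos(x), computed on the forward pass.
@jax.custom_vjp
def g_save_cos(x):
    return jnp.sin(x)

def g_save_cos_fwd(x):
    return jnp.sin(x), jnp.cos(x)  # residual: cos(x)

def g_save_cos_bwd(cos_x, ct):
    return (cos_x * ct,)  # no recomputation needed

g_save_cos.defvjp(g_save_cos_fwd, g_save_cos_bwd)

# custom_jvp version: JAX's partial evaluation decides what is saved.
@jax.custom_jvp
def g_jvp(x):
    return jnp.sin(x)

@g_jvp.defjvp
def g_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    return jnp.sin(x), jnp.cos(x) * t

# jax.checkpoint with its default policy saves only the inputs of g_jvp
# and recomputes everything else during the backward pass.
g_jvp_remat = jax.checkpoint(g_jvp)

def grad_of_sum(fn):
    # Gradient of sum(fn(v)) with respect to v.
    return jax.grad(lambda v: jnp.sum(fn(v)))

x = jnp.linspace(0.0, 1.0, 5)
grads = [grad_of_sum(fn)(x) for fn in (g_save_input, g_save_cos, g_jvp, g_jvp_remat)]
```

In recent JAX versions, jax.ad_checkpoint.print_saved_residuals can list which values each variant actually saves; treat its exact output format as version-dependent.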