

Thanks for the question!

Whether you use custom_jvp or custom_vjp need not affect performance at all. In fact you can "spell" the exact same computation either way in most cases.
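As a minimal sketch of "spelling" one computation both ways, here is sin(x) with a hand-written derivative rule in each style. The names f_jvp, f_vjp and their rule functions are hypothetical, chosen only for illustration. With custom_jvp, JAX derives reverse mode by linearizing and transposing the forward rule. With custom_vjp, you write the forward pass (returning residuals) and the backward pass yourself.

```python
import jax
import jax.numpy as jnp

# Version 1: a custom JVP (forward-mode) rule. jax.grad still works because
# JAX linearizes and transposes this rule to get reverse mode.
@jax.custom_jvp
def f_jvp(x):
    return jnp.sin(x)

@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = jnp.sin(x)
    # Return (primal output, tangent output).
    return y, jnp.cos(x) * x_dot

# Version 2: a custom VJP (reverse-mode) rule for the same function.
@jax.custom_vjp
def f_vjp(x):
    return jnp.sin(x)

def f_vjp_fwd(x):
    # Forward pass: return the output plus the residual cos(x) for the backward pass.
    return jnp.sin(x), jnp.cos(x)

def f_vjp_bwd(cos_x, g):
    # Backward pass: one cotangent per primal input, returned as a tuple.
    return (cos_x * g,)

f_vjp.defvjp(f_vjp_fwd, f_vjp_bwd)

x = 0.5
g1 = jax.grad(f_jvp)(x)
g2 = jax.grad(f_vjp)(x)
print(g1, g2)
```

Both versions compute the same gradient, cos(x), so the choice between them here is a matter of which rule is more convenient to write rather than of speed.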

I say "in most cases" as a hedge because there are cases where they might differ. Here is a non-exhaustive list which comes to mind:

  • With custom_vjp you can more directly control what the "residuals" are, i.e., which values are saved from the forward pass for use on the backward pass. With custom_jvp you leave the choice of residuals to JAX's partial evaluation machinery. Those can be controlled too (e.g. with jax.remat/jax.checkpoint), but that control isn't "built in" the way it is with custom_vjp. Choice of residuals can…
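To illustrate the residual point, here is a sketch assuming an example function exp(sin(x)); the function and all names (f_save, f_recompute, f_jvp_remat) are hypothetical. With custom_vjp, the forward rule decides explicitly what gets saved: strategy A stores the output and cos(x), while strategy B stores only x and recomputes on the backward pass. With custom_jvp, JAX's partial evaluation picks the residuals, and wrapping the function in jax.checkpoint is one way to steer it toward recomputation instead.

```python
import jax
import jax.numpy as jnp

def _f(x):
    return jnp.exp(jnp.sin(x))

# Strategy A (custom_vjp): save the output y and cos(x) as residuals.
# More memory held between passes, less recomputation on the backward pass.
@jax.custom_vjp
def f_save(x):
    return _f(x)

def f_save_fwd(x):
    y = _f(x)
    return y, (y, jnp.cos(x))

def f_save_bwd(res, g):
    y, cos_x = res
    return (g * y * cos_x,)

f_save.defvjp(f_save_fwd, f_save_bwd)

# Strategy B (custom_vjp): save only the input x, recompute everything later.
@jax.custom_vjp
def f_recompute(x):
    return _f(x)

def f_recompute_fwd(x):
    return _f(x), x

def f_recompute_bwd(x, g):
    return (g * _f(x) * jnp.cos(x),)

f_recompute.defvjp(f_recompute_fwd, f_recompute_bwd)

# custom_jvp version: residuals are chosen by JAX's partial evaluation.
# jax.checkpoint asks JAX to recompute rather than store intermediates.
@jax.custom_jvp
def f_jvp(x):
    return _f(x)

@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = _f(x)
    return y, y * jnp.cos(x) * x_dot

f_jvp_remat = jax.checkpoint(f_jvp)

# All three give the same gradient; they differ only in what is stored vs. recomputed.
xs = jnp.linspace(-1.0, 1.0, 5)
grads = [jax.vmap(jax.grad(fn))(xs) for fn in (f_save, f_recompute, f_jvp_remat)]
```

Which strategy is faster depends on the cost of recomputation versus memory traffic for the stored values, which is one concrete way the two APIs can end up performing differently.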


@antalszava

Answer selected by antalszava
Category
Q&A