
Hi, thanks for the question! In JAX, forward-mode autodiff is handled by jvp (Jacobian-vector product), while reverse-mode autodiff is handled by vjp (vector-Jacobian product).
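As a quick illustration of the two modes, here is a minimal sketch using a made-up elementwise function f (the names f, x, v are just for this example). jax.jvp pushes a tangent vector forward through f, and jax.vjp returns a pullback that maps a cotangent vector backward. Because f is elementwise, its Jacobian is diagonal, so both products agree here.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise function: f(x) = sin(x) * x
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
v = jnp.ones_like(x)

# Forward mode: push the tangent v through f, giving J @ v.
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: get the pullback, then pull the cotangent v back, giving v @ J.
_, f_vjp = jax.vjp(f, x)
(vj,) = f_vjp(v)  # one cotangent per primal input

print(jv, vj)
```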

When you decorate a function with custom_vjp, you define how reverse-mode autodiff should behave, but you have not told JAX how forward-mode autodiff should behave, so forward-mode transformations such as jvp and jacfwd fail on that function.
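A minimal sketch of the situation, assuming a toy function f with a hand-written VJP (f, f_fwd and f_bwd are illustrative names). jax.grad uses the custom rule, while jax.jvp raises an error because no forward-mode rule exists. In the JAX versions we are aware of, that error is a TypeError saying forward-mode autodiff can't be applied to a custom_vjp function.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Forward pass: return the output plus residuals saved for the backward pass.
    return jnp.sin(x), jnp.cos(x)

def f_bwd(cos_x, g):
    # Backward pass: map the output cotangent g to a tuple of input cotangents.
    return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

# Reverse mode works and uses f_bwd.
grad_at_1 = jax.grad(f)(1.0)

# Forward mode has no rule to use, so JAX raises an error.
jvp_error = None
try:
    jax.jvp(f, (1.0,), (1.0,))
except TypeError as err:
    jvp_error = err

print(grad_at_1, jvp_error)
```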

To fix this, you can either stick to reverse-mode autodiff transformations (vjp, jacrev, jacobian, or grad, all of which use reverse-mode autodiff), or define your function with custom_jvp instead, which specifies how it behaves under forward-mode autodiff. A custom_jvp rule also covers reverse mode, because JAX derives the VJP by transposing the JVP rule, which must be linear in the tangents.
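Here is a sketch of the second option, assuming the same toy sin function rewritten with custom_jvp (g and g_jvp are hypothetical names). With a JVP rule in place, jax.jacfwd works, and jax.jacrev works too because JAX transposes the rule for reverse mode.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def g(x):
    return jnp.sin(x)

@g.defjvp
def g_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The tangent output is linear in x_dot, so JAX can transpose it for reverse mode.
    return jnp.sin(x), jnp.cos(x) * x_dot

x = jnp.array([0.0, 1.0, 2.0])

fwd_jac = jax.jacfwd(g)(x)  # forward mode, uses g_jvp directly
rev_jac = jax.jacrev(g)(x)  # reverse mode, derived from g_jvp by transposition

print(fwd_jac)
```

Note that a function can carry either a custom_jvp or a custom_vjp rule, not both, so this replaces the custom_vjp definition rather than adding to it.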

For more details, see Custom derivative rules for Python code.

Answer selected by miaoL1