Is there a way to disable forward evaluation while using VJP? #7361
Unanswered
aonurdasdemir asked this question in Q&A
-
JAX needs to evaluate the forward pass in order to compute the intermediate quantities used in the VJP calculation, so there's not really any way to avoid it. Probably your best option is to use …
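A minimal sketch of one way to avoid running an expensive function twice. This is an assumption, not necessarily what the truncated reply goes on to recommend: call jax.vjp in place of the original forward call, use its primal output as the forward result, and keep the returned VJP function for later. expensive_f and the calls counter are hypothetical stand-ins used only to show that the Python body runs once.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}  # counts how many times the Python body of expensive_f runs

def expensive_f(x):
    # Hypothetical stand-in for a costly forward computation.
    calls["n"] += 1
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)

# One call: jax.vjp runs the forward pass and returns both the primal
# output and a function that evaluates the VJP from saved intermediates.
y, vjp_fn = jax.vjp(expensive_f, x)

# Use y wherever the forward result was needed, instead of calling
# expensive_f(x) separately. vjp_fn returns one cotangent per primal input.
(x_bar,) = vjp_fn(jnp.ones_like(y))

# vjp_fn can be called again with a different cotangent without rerunning
# the forward function.
(x_bar2,) = vjp_fn(2.0 * jnp.ones_like(y))
```

The design choice here is to restructure the calling code so that the forward evaluation and the VJP share a single jax.vjp call, rather than trying to suppress the forward pass inside jax.vjp itself.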
-
I use VJP frequently in my project. jax.vjp runs the function being differentiated and returns primals_out together with a callable VJP function. For example, the JAX documentation defines a custom VJP like this:
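A minimal sketch of the jax.custom_vjp pattern, modeled on the sin(x) * y example in the JAX custom-derivatives documentation (assumed here to be the example being referred to). The forward rule f_fwd calls f itself, which is why the forward function is evaluated whenever the VJP is requested.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Forward rule: evaluates f and returns residuals for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    # Backward rule: maps residuals and the output cotangent g to
    # cotangents for (x, y).
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

# jax.vjp invokes f_fwd, so the forward function is evaluated here.
out, f_vjp = jax.vjp(f, 1.0, 2.0)
x_bar, y_bar = f_vjp(jnp.ones_like(out))
```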
This example shows that the forward function must be evaluated when computing a VJP, and the same is true for the regular jax.vjp, not only for a custom-defined one. However, the function is expensive to evaluate, and since I have already run it elsewhere in my code, I don't want the VJP computation to evaluate it a second time.
So, is there any way to indicate that a function should not be evaluated when computing its VJP?
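A minimal sketch, under the assumption that the backward rule is written by hand as in the custom_vjp example: if the forward pass already saves the residuals the backward rule needs, the backward rule can be called directly on those stored residuals, so the forward function is not evaluated again. f_with_residuals is a hypothetical helper, not a JAX API.

```python
import jax.numpy as jnp

def f_with_residuals(x, y):
    # Hypothetical helper: runs the forward pass once and keeps the
    # intermediates that the backward rule needs.
    sin_x = jnp.sin(x)
    out = sin_x * y
    residuals = (jnp.cos(x), sin_x, y)
    return out, residuals

def f_bwd(res, g):
    # Same backward rule as in the custom_vjp definition; it only reads the
    # residuals and never calls the forward function.
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

# Forward pass, run once wherever the code originally needed f(x, y).
out, res = f_with_residuals(1.0, 2.0)

# Later: the VJP is computed from the stored residuals only.
x_bar, y_bar = f_bwd(res, jnp.ones_like(out))
```

This only works when the VJP is derived manually; for an automatically differentiated function, the residuals live inside the function returned by jax.vjp.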