Slow gradient calculation for convnets when loss includes both jvp and vjp #15584
Unanswered
sorenhauberg asked this question in Q&A
We are working on some code that requires us to compute the gradient of a loss, including $J J^T \epsilon$, where $J$ is the Jacobian of a neural network, and $\epsilon$ can be any vector. The actual loss is a bit more complex, but the above demonstrates the issue we face.
JAX makes it wonderfully easy to evaluate this loss:
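Roughly, the pattern looks like this (a simplified sketch, not the exact code in the Colab MWE; `apply_fn`, the toy parameters, and the scalar reduction are placeholders, and the Jacobian is taken with respect to the inputs here):

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Toy stand-in for the real network (a single dense layer).
    w, b = params
    return jnp.tanh(x @ w + b)


def loss_fn(params, x, eps):
    # View the network as a function of its input, so that J = df/dx.
    f = lambda xx: apply_fn(params, xx)
    _, vjp_fn = jax.vjp(f, x)                  # gives v -> J^T v
    (jt_eps,) = vjp_fn(eps)                    # J^T eps
    _, jjt_eps = jax.jvp(f, (x,), (jt_eps,))   # J (J^T eps)
    return jnp.sum(jjt_eps * eps)              # eps^T J J^T eps (placeholder reduction)


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (8, 4))
params = (jax.random.normal(k2, (4, 3)), jnp.zeros(3))
eps = jax.random.normal(k3, (8, 3))

loss = jax.jit(loss_fn)(params, x, eps)              # evaluating the loss is fast for all models
grads = jax.jit(jax.grad(loss_fn))(params, x, eps)   # this is what slows down for ConvNets
```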
For neural networks with linear layers, it is quite fast to evaluate both this loss and its gradient. However, when using convolutional networks, we see drastic slowdowns. In particular, I get times like the following, where the three ConvNets have 1, 2, and 3 convolution layers, respectively. Note that if I only want to evaluate the loss (not its gradient), then all models are fast.
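For reference, a convolutional stand-in for `apply_fn` in the sketch above would look roughly like this (hypothetical; the actual ConvNets and timings are in the linked Colab):

```python
import jax
import jax.numpy as jnp


def conv_apply_fn(params, x):
    # x: (batch, channels, height, width); each kernel in OIHW layout.
    for kernel in params:
        x = jax.lax.conv_general_dilated(
            x, kernel, window_strides=(1, 1), padding="SAME")
        x = jnp.tanh(x)
    return x


# e.g. a 2-layer ConvNet: stack two kernels (channel counts are illustrative)
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = [jax.random.normal(k1, (8, 1, 3, 3)),   # 1 -> 8 channels
          jax.random.normal(k2, (8, 8, 3, 3))]   # 8 -> 8 channels
```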
I'm a bit stuck on how to proceed and would appreciate any pointers on getting around this slowdown. Is this a fundamental issue with convolutions, a matter of my implementation, or a bug in JAX that produces a slow code path? Any suggestions for making the code faster are welcome.
A complete MWE is available on Google Colab here.