Slow compilation and OOM gradient computation #8365
Replies: 1 comment 9 replies
Hi all,
First of all, thank you so much for the awesome project. I was inspired by both Brax and a talk I saw @mattjj give recently, and thought that I could use JAX for my soft-structure differentiable simulation needs. I tried to port the diffmpm module from the difftaichi paper to JAX. Linked here is my attempt, where I've commented in two questions at the bottom (Line 294/298). I want to note that this is still not working perfectly - I know it has correctness bugs, but my questions here are specifically about performance.
I am extremely happy with the forward evaluation speed. It runs at approximately the speed I was hoping for, without too many software-engineering tricks. However, I am finding two aspects - compilation speed and gradient memory usage - to be a bit problematic.
I created a simulation function by jit-ing a few subroutines and executing them in a Python loop. This loop is executed n times, after which a loss value is returned.
Question 1: Is there any way to then jit the entire simulation, and is there even a point in doing so? I was unsure a) whether jits can be nested into one another, b) whether it's good practice to do so, and c) if it is allowed, why compilation takes so long. I have given up on compiling the full simulation, as it runs for minutes without terminating, which is not particularly useful.
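For reference, the pattern described above can be sketched with a toy step function (everything here is a hypothetical stand-in, not the linked diffmpm port). Nested jits are allowed - an inner jitted function is simply traced inline into the outer one - and the long compile time typically comes from a Python `for` loop unrolling into n copies of the step; `lax.scan` compiles the step body once instead:

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def step(state):
    # Toy stand-in for one simulation substep (not the real diffmpm kernel).
    return state * 0.99 + jnp.sin(state) * 0.01


# n_steps must be static because it fixes the scan length.
@partial(jax.jit, static_argnums=1)
def simulate(state, n_steps):
    # jit calls may be nested: `step` is traced inline here.
    # A Python `for` loop inside jit would unroll into n_steps copies of
    # the step, which is what makes compilation run for minutes;
    # lax.scan keeps the loop as a single XLA op with one step body.
    def body(carry, _):
        return step(carry), None

    final, _ = jax.lax.scan(body, state, None, length=n_steps)
    return jnp.sum(final)


loss = simulate(jnp.ones(8), 2000)
```

With this structure, recompilation happens only when the state shape or n_steps changes, not on every call.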
Question 2: One of the reasons JAX is so exciting to me is that it is differentiable and easily parallelizable, and higher-order gradients appear to be simple to compute and accelerated. Unfortunately, attempting to compute grad of my large function leads to out-of-memory issues on my 8GB GPU. Even accounting for stored intermediates, I'm genuinely not sure this should run out of memory: if my back-of-the-envelope math is right, each state variable is under 100kB, so even keeping all of them over 2,000 steps shouldn't take more than a few GB of memory - and that's without any compiler optimizations. Right now, I can only handle a few hundred steps. I could, of course, compute the grad of each simulation step and perform backpropagation manually, but this seems to defeat part of the purpose of JAX, and would make it impossible to efficiently compute higher-order derivatives. Is there a reason my simulation is using so much memory? If this is the expected memory consumption, then - since each loop iteration of my simulation is identical in structure - is there anything I can do to simplify the compilation procedure?
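One knob that may be relevant here is `jax.checkpoint` (a.k.a. `jax.remat`), which trades compute for memory in reverse-mode AD by recomputing a function's intermediates on the backward pass instead of storing them. A minimal sketch, again with a hypothetical toy step rather than the actual simulation:

```python
import jax
import jax.numpy as jnp


def step(state):
    # Hypothetical substep; the real state would be the MPM grid/particle arrays.
    return state * 0.99 + jnp.sin(state) * 0.01


def simulate(state, n_steps):
    # jax.checkpoint makes reverse-mode AD recompute this step's internal
    # intermediates during the backward pass instead of keeping them all
    # live, which can sharply reduce peak memory for long scans.
    checkpointed_step = jax.checkpoint(step)

    def body(carry, _):
        return checkpointed_step(carry), None

    final, _ = jax.lax.scan(body, state, None, length=n_steps)
    return jnp.sum(final)


grad_fn = jax.jit(jax.grad(simulate), static_argnums=1)
g = grad_fn(jnp.ones(8), 1000)
```

Under scan, the per-step carries are still saved for the backward pass, but the per-step intermediates inside `step` are not, so memory scales with the state size times the step count rather than with everything XLA would otherwise keep.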
If either issue is a bug as opposed to user error, please let me know and I will file a bug.