Why do later iterations of my code run faster? #16247
I have a training loop that I run several times in a row as part of an active learning strategy (each iteration, I optimize a vector that tells me where to sample next in the dataset) using optax. However, I'm finding that the training loop runs much faster in iterations after the first one. I'm wondering why this happens (it seems like some kind of memory-allocation effect), and whether I can somehow exploit it to make even the first iteration run faster. The training loop is pretty simple:
Later iterations of this code run at more than twice the speed of the first iteration, and I'm designing a real-time experiment, so speed is key.
Replies: 1 comment
It's because the JIT is compiling your `update()` code, so it may not be easy to make the first iteration faster unless there's some low-hanging fruit you can rewrite with the `lax` loop primitives. Incidentally, you may see a speedup if you move the `optimizer.update` call inside the JIT-ed region of code as well. It might also (not really sure) reduce peak memory usage: the non-JITted update call needs both the pre- and post-update versions of your parameters in accelerator memory, whereas the JIT should figure out that the update can be done in place.