Questions about vmap on GPU and execution time scaling w.r.t. input size when only using the GPU #19618
-
Hello everyone! I've been messing around with JAX (in reinforcement learning) for around 6 months now, but my knowledge of the specifics is still quite limited when it comes to how JAX actually runs things under the hood. My question is about how JAX handles computations under `vmap` on GPU as the number of vectorized operations increases.

Here is my use case to clarify the question: I have been running experiments on GPU (GeForce GTX TITAN X), and when I increase the input size (the number of seeds I vmap over), execution time grows. I also observe that the share of GPU time spent accessing memory increases with input size (0% for 1 seed, ~3% for 10 seeds, ~30% for 100 seeds). So it could be that the code does not run exclusively on the GPU, and that the CPU-GPU dialogue that arises at larger input sizes slows down execution. Would there be a way to find out where in my code this could come from?

Here is the link to the repo if you have the time to dive in (the code is a bit complex and hard to reduce to a meaningful example in my case): https://github.com/YannBerthelot/jaxppo/tree/rnn

Thanks in advance for any help or tips on how to improve the performance of my approach! Do not hesitate to ask if you need more details.

(I believe my question is linked to this other one, #19103; however, in my case I believe the whole program runs on the GPU and not just a part of it, so there should be plenty of room for parallel optimization.)
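In case it's useful, here is roughly how I have been timing and tracing things. This is a minimal sketch, not my actual code: `train_step` is a made-up placeholder for my vmapped training function, and the trace directory is arbitrary:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for my real vmapped training step.
def train_step(seed):
    return jnp.sin(seed) ** 2

seeds = jnp.arange(100.0)  # e.g. 100 seeds vectorized with vmap

# Capture a trace to inspect in TensorBoard/Perfetto: it shows which ops
# run on the GPU and where host<->device transfers happen.
with jax.profiler.trace("/tmp/jax-trace"):
    out = jax.vmap(train_step)(seeds)
    out.block_until_ready()  # wait for async dispatch so the trace covers the work
```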
-
Hi - so I think the main misunderstanding here is that `vmap` doesn't explicitly have anything to do with parallelism: it just converts unbatched instructions to batched instructions. So a vmapped vector product is just a matrix product of the batched input, and a vmapped sum is just a sum along one axis of the batched input.

As for why large input sizes slow down your computation... well, the size of the computation in `vmap` scales linearly with the size of the input. From the times you quote, if I'm understanding correctly, the actual wall-time scaling is very sub-linear in the number of batches, meaning that for your program the compiler is making very good use of the hardware.

Does that help answer your question?
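To make the batching point concrete, here is a minimal sketch (the function and shapes are made up, not taken from your repo) showing that a vmapped dot product lowers to a single batched instruction rather than a loop:

```python
import jax
import jax.numpy as jnp

w = jnp.ones(4)  # fixed weight vector (hypothetical)

def dot(v):
    # unbatched: one vector-vector product
    return jnp.dot(v, w)

xs = jnp.ones((8, 4))  # a batch of 8 vectors

# The printed jaxpr contains a single dot_general over the whole batch:
# vmap rewrote the per-example dot into one batched matrix-vector product.
print(jax.make_jaxpr(jax.vmap(dot))(xs))
```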
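And if you want to check the wall-time scaling yourself, a rough measurement pattern (the function is the same toy example as above; the numbers will of course depend on your hardware) looks like:

```python
import time
import jax
import jax.numpy as jnp

w = jnp.ones(4)
f = jax.jit(jax.vmap(lambda v: jnp.dot(v, w)))

for n in (1, 10, 100):
    xs = jnp.ones((n, 4))
    f(xs).block_until_ready()  # warm-up: compile for this batch shape
    t0 = time.perf_counter()
    f(xs).block_until_ready()  # block so the async GPU work is included
    print(f"batch={n}: {time.perf_counter() - t0:.6f}s")
```

If those timings grow much more slowly than the batch size, the GPU still has idle capacity at small batches, which is consistent with what you are seeing.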