I am writing a multi-GPU particle simulation code with jax + CUDA-FFI. Particles are ordered along a space-filling curve, so any GPU may need data from any other GPU, but in general most communication happens between nearby GPUs and only a little with far-away GPUs.
However, today I wrote a little benchmark to understand its performance... and I am shocked. Below you can find a plot of the time it takes to finish a simple jitted+shard-mapped function with a perfectly balanced all-to-all communication on 64 GPUs (16 nodes, 4 GPUs per node). The x-axis shows the amount of data that is created on every GPU for the communication (so e.g. at 2GB each GPU sends around 32MB to each other GPU). The lines compare the performance of jax.lax.ragged_all_to_all versus jax.lax.all_to_all. Obviously jax.lax.all_to_all is more optimized for this scenario, so it is understandable that jax.lax.ragged_all_to_all would be notably slower. What puzzles me, however, is how large the gap is: the communication time for essentially empty messages is about 10 times worse than that of a fixed-size jax.lax.all_to_all. An empty ragged all-to-all communication is about as slow as a fixed all-to-all where each node sends ~8MB to each other node...
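In case it helps to see what I am timing, the two jitted+shard-mapped bodies look roughly like the sketch below. This is a simplified illustration, not the full ragged_all2all.py: the mesh/axis names and function names are made up, and the exact shard_map import path and keyword names differ a bit between jax versions (older versions may also need replication checking disabled).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("i",))
n = mesh.size  # 64 GPUs in the benchmark setup

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def fixed_a2a(x):
    # x is the local shard; its leading axis holds n equal chunks, one per peer.
    return jax.lax.all_to_all(x, "i", split_axis=0, concat_axis=0, tiled=True)

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def ragged_a2a(x):
    # The same balanced exchange expressed through the ragged primitive:
    # every peer sends/receives a chunk of `rows` leading-axis rows.
    rows = x.shape[0] // n
    me = jax.lax.axis_index("i")
    out = jnp.zeros_like(x)  # preallocated receive buffer (balanced, so same size)
    input_offsets = jnp.arange(n, dtype=jnp.int32) * rows       # chunk starts in x
    send_sizes = jnp.full((n,), rows, dtype=jnp.int32)           # rows sent to each peer
    output_offsets = jnp.full((n,), me * rows, dtype=jnp.int32)  # where our chunk lands remotely
    recv_sizes = jnp.full((n,), rows, dtype=jnp.int32)           # rows received from each peer
    return jax.lax.ragged_all_to_all(
        x, out, input_offsets, send_sizes, output_offsets, recv_sizes, axis_name="i")
```

Both take a global array sharded along the mesh axis; in the balanced case the ragged variant is handed exactly the same chunk layout as the fixed one, which is why I expected roughly the same cost.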
Here is the script that produces this plot: ragged_all2all.py (run on a slurm system with 1 process per GPU). You can find the system specs here under compute nodes / booster partition. I am using CUDA 12.9 and jax 0.8.2. Does anyone have suggestions for getting more out of the ragged_all_to_all communication? I also wonder how to explain the extreme difference between jax.lax.all_to_all and jax.lax.ragged_all_to_all; I would have expected them to roughly need to do the same thing here, and understanding this might help me make the right decisions moving forward. On another note: does anyone know whether it is an option to implement custom communication kernels inside jax's FFI with NCCL?
Replies: 1 comment


With the help of a colleague, we found the problem. My slurm job was allocated with 1 CPU per task (the default on my cluster) and -- as recommended by jax -- with 1 task per GPU. I had subconsciously assumed that jax only uses the CPUs for compiling code and that the number of CPUs is therefore irrelevant. However, it turns out that the CPUs actually do a lot of work during communication! Simply by adding `#SBATCH --cpus-per-task=8` or `#SBATCH --exclusive` to my slurm job script, performance jumps up dramatically -- making the performance consistent between ragged and fixed all-to-all and also improving the fixed all-to-all notably:
For the low-data limit this is app…
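For reference, the relevant part of a job script with this fix could look like the sketch below. The node/GPU counts match the 16-node, 4-GPUs-per-node setup from the benchmark, but the gres line and exact resource flags are cluster-specific assumptions, not copied from my script.

```bash
#!/bin/bash
#SBATCH --nodes=16               # 16 nodes x 4 GPUs = 64 GPUs, as in the benchmark
#SBATCH --ntasks-per-node=4      # 1 task per GPU, as recommended by jax
#SBATCH --gres=gpu:4             # assumption: how GPUs are requested varies per cluster
#SBATCH --cpus-per-task=8        # the fix: give each process several CPU cores
# ...or alternatively request whole nodes:
##SBATCH --exclusive

srun python ragged_all2all.py
```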