With the help of a colleague, we found the problem. My Slurm job was allocated 1 CPU per task (the default on my cluster) and, as recommended by JAX, 1 task per GPU. I had subconsciously assumed that JAX only uses the CPUs for compiling code, so the number of CPUs would be completely irrelevant. However, it turns out that the CPUs actually do a lot of work during communication! Simply by adding
#SBATCH --cpus-per-task=8
or
#SBATCH --exclusive
to my Slurm job script, performance jumps up dramatically. This makes the performance consistent between ragged and fixed all-to-all, and it also notably improves the fixed all-to-all.
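
To check whether a job is hitting the same limit, it can help to print how many CPU cores each task is actually allowed to use before any JAX work starts. The following is a minimal sketch, not part of the original fix: the helper names `available_cpus` and `report_cpu_allocation` and the `min_cpus` threshold are hypothetical, and it assumes the cluster binds tasks to their allocated cores (for example via cgroups or task affinity), so that the affinity mask reflects the Slurm allocation.

```python
import os


def available_cpus() -> int:
    """Return the number of CPU cores this process is allowed to run on."""
    # os.sched_getaffinity exists on Linux; fall back to the machine-wide
    # count elsewhere (the fallback ignores any per-job restriction).
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def report_cpu_allocation(min_cpus: int = 2) -> str:
    """Summarize the CPU allocation; min_cpus is an arbitrary, hypothetical threshold."""
    # Slurm sets SLURM_CPUS_PER_TASK only when --cpus-per-task was requested.
    requested = os.environ.get("SLURM_CPUS_PER_TASK", "unset")
    usable = available_cpus()
    message = f"SLURM_CPUS_PER_TASK={requested}, usable cores={usable}"
    if usable < min_cpus:
        message += " (warning: few cores available; communication may be CPU-bound)"
    return message


if __name__ == "__main__":
    print(report_cpu_allocation())
```

Running this at the top of each task (one task per GPU) should show a single usable core with the cluster default described above, and more once `--cpus-per-task` or `--exclusive` is set, provided the binding assumption holds.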

For the low-data limit this is app…

Answer selected by jstuecker