What does the performance of lax.pmean depend on? #19651
Replies: 1 comment
-
This has been resolved: I was testing with very large matrices, and NCCL's ring all-reduce handles matrices of that size poorly. I don't know how JAX deals with this, but it caused a significant performance degradation. In practice, gradients consist of a set of smaller matrices, so this kind of slowdown should not affect real usage. When I instead distributed four 128M matrices across the 4 GPUs and ran pmean, I observed a significant speed-up from NVLink.
[screenshots: pmean results for server 0 and server 1, not reproduced here]
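For reference, a minimal sketch of the "real usage" case described above, with a made-up toy model (the names and sizes are illustrative, not the original test setup): the gradients form a pytree of modest-sized arrays, and lax.pmean averages the whole pytree across devices in one call.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

# Toy parameters: a pytree of several modest-sized arrays, the shape gradients
# usually take in practice (rather than one huge matrix).
params = {
    "w1": jnp.ones((1024, 1024)) * 0.01,
    "b1": jnp.zeros((1024,)),
    "w2": jnp.ones((1024, 10)) * 0.01,
}

def loss_fn(p, x):
    h = jnp.tanh(x @ p["w1"] + p["b1"])
    return jnp.mean((h @ p["w2"]) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def ddp_grads(p, x):
    grads = jax.grad(loss_fn)(p, x)
    # lax.pmean accepts a whole pytree: every leaf (each gradient array)
    # is averaged across the devices participating in the pmap.
    return lax.pmean(grads, axis_name="batch")

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 32, 1024))  # one batch shard per device
# Replicate the parameters by stacking along a leading device axis.
params_repl = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n_dev), params)
grads = ddp_grads(params_repl, x)  # averaged gradients, identical on every device
```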
-
I use lax.pmean to implement DDP. As I understand it, lax.pmean computes the average of the gradients across all GPUs, so speeding up lax.pmean is key to efficient parallel training. I tested the performance of lax.pmean on two servers: one is an AWS G5 instance with 4 A10G GPUs, and the other is a server with 4 V100 SXM2 GPUs connected by NVLink.
To my surprise, lax.pmean performed better on the 4×A10G instance without NVLink. I wonder why this is, and why NVLink did not help as expected.
The test code is as follows:
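(The original code listing was not preserved in this export. The sketch below is a stand-in, not the original test code: one minimal way to time lax.pmean on a single large matrix per device, with arbitrary sizes and iteration counts.)

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()

# pmean over a named axis averages its argument across all devices in the pmap.
allreduce_mean = jax.pmap(lambda x: lax.pmean(x, axis_name="i"), axis_name="i")

# One large float32 matrix per device (shape chosen arbitrarily for illustration).
x = jnp.ones((n_dev, 4096, 4096), dtype=jnp.float32)

# Compile and warm up once, then time repeated collectives.
allreduce_mean(x).block_until_ready()
n_iters = 20
t0 = time.perf_counter()
for _ in range(n_iters):
    out = allreduce_mean(x)
out.block_until_ready()
print(f"avg lax.pmean time: {(time.perf_counter() - t0) / n_iters * 1e3:.2f} ms")
```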
The software environment is as follows (identical on both servers):
[screenshot: software environment]
The hardware of the two servers is as follows:
[screenshots: hardware of server 0 and server 1]
The results are as follows:
[screenshots: pmean timings on server 0 and server 1]
I also checked NCCL and confirmed that JAX used NVLink to run pmean. Additionally, I tried disabling NVLink on server 0, with the result as follows:
[screenshot: pmean timing on server 0 with NVLink disabled]
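One common way to see which transport NCCL picks, and to force it off NVLink for an A/B test, is through NCCL's standard environment variables, set before JAX (and therefore NCCL) initialises. These are generic NCCL settings, not JAX-specific flags; a minimal sketch:

```python
import os

# NCCL_DEBUG=INFO makes NCCL log which transport (P2P/NVLink, shared memory,
# or socket) each communicator uses, so you can confirm NVLink is in play.
os.environ["NCCL_DEBUG"] = "INFO"

# NCCL_P2P_DISABLE=1 turns off peer-to-peer transfers (including NVLink), which
# is one way to rerun the same collective without NVLink for comparison.
# os.environ["NCCL_P2P_DISABLE"] = "1"

import jax  # import after setting the variables so NCCL picks them up
from jax import lax
```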
I would like to know why the 4×V100 server shows no advantage for parallel training, and how I can improve it. Thank you very much for any answers!