Deadlock Running Simple All Gather XLA Graph Across Two Local GPU Devices #16315
Unanswered
xanderdunn asked this question in Q&A
Replies: 0 comments
I have a simple XLA HLO graph with a single all gather across two devices:
The graph for rank 0: rust_hlo_rank_0.test_all_gather_dim0.pb.zip
The graph for rank 1: rust_hlo_rank_1.test_all_gather_dim0.pb.zip
These graphs are identical.
I am attempting to run this graph distributed across two processes on the same GPU host with this script: run_xla_cpu_gpu.py
I followed the documentation here and here in writing this script.
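For reference, here is a minimal sketch of the multi-process setup I believe a script like this needs; the coordinator port and the `--rank` flag are assumptions on my part, not the actual contents of run_xla_cpu_gpu.py:

```python
# Minimal sketch of the multi-process setup (assumed, not the actual contents
# of run_xla_cpu_gpu.py). Each process is restricted to one GPU, e.g. by
# launching with CUDA_VISIBLE_DEVICES=0 / CUDA_VISIBLE_DEVICES=1, so the two
# ranks don't both claim all eight A100s.
import argparse
import jax

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, required=True)  # hypothetical flag
rank = parser.parse_args().rank

# Rank 0 hosts the coordination service; both ranks connect to it.
jax.distributed.initialize(
    coordinator_address="localhost:12345",  # assumed free port on this host
    num_processes=2,
    process_id=rank,
)

# Sanity print: each rank should report process_index == rank,
# one local device, and a global device count of 2.
print(jax.process_index(), jax.local_devices(), jax.device_count())
```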
So now I run two processes, one for each of the two ranks:
However, both processes hang indefinitely:
I'm guessing that I've misconfigured the distributed setup somehow, or that I'm misusing the jax API here in some way. I'm running on a machine with 8 A100s, and NCCL communicates between devices fine in other applications. My guess is that the graph execution hangs because each process can't find or can't communicate with the other device. Any ideas what might be wrong here?
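One way to narrow this down (a sketch of my own, assuming the jax.distributed.initialize setup above and one visible GPU per process) would be a pure-JAX all gather across the two processes; if this also hangs, the problem is in the distributed/NCCL setup rather than in the hand-built HLO graph:

```python
# Sketch: run a JAX-level all gather across the two processes to check the
# NCCL path independently of the raw XLA executable. Assumes
# jax.distributed.initialize has already run and each process sees one GPU.
import jax
import jax.numpy as jnp

# Leading axis = number of local devices (1 per process here); the pmap
# axis "i" spans all devices across both processes, so its size is 2.
x = jnp.ones((jax.local_device_count(), 4))
gathered = jax.pmap(
    lambda v: jax.lax.all_gather(v, axis_name="i"),
    axis_name="i",
)(x)

# If comms work, each rank should see shape (1, 2, 4) filled with 1.0.
print(jax.process_index(), gathered.shape)
```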
I notice the existence of both `execute_sharded*` functions on `Executable`, as well as a `DistributedRuntimeClient` in `xla_extension/__init__.py`. Is one of these appropriate here?
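For context, this is roughly the low-level path I have in mind (a sketch only; the file name is the unzipped attachment, and I'm assuming `backend.compile` in this JAX version still accepts a serialized `XlaComputation` rather than only MLIR):

```python
# Rough sketch of the low-level compile path (assumptions noted above).
from jax.lib import xla_bridge, xla_client

backend = xla_bridge.get_backend("gpu")

# The attachment is assumed to be a serialized HloModuleProto.
with open("rust_hlo_rank_0.test_all_gather_dim0.pb", "rb") as f:
    computation = xla_client.XlaComputation(f.read())
print(computation.as_hlo_text())  # confirm the all-gather and replica groups

options = xla_client.CompileOptions()
build_options = options.executable_build_options
build_options.num_replicas = 2    # one replica per rank
build_options.num_partitions = 1
executable = backend.compile(computation, options)
# Each process would then feed its local arguments to one of the
# executable.execute_sharded* methods mentioned above -- I'm unsure which.
```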
Note that the graph does run and return results on the CPU backend, but the result is incorrect, so maybe collective comms aren't actually supported there? On a correct all gather I would expect the output to be all 1.0, since both ranks have all-1.0 inputs.
Thanks!
I'm running jax 0.4.11 on CUDA 12 and cudnn 8.9.