Replies: 1 comment 1 reply
-
Hi Roeland, I was looking through here for MPI content and saw this old question; sorry if you already got an answer.

In my experience this is doable. Many MPI job launchers set per-process environment variables that enable exactly this; try, for example, Open MPI's `OMPI_COMM_WORLD_LOCAL_RANK`. In my own code, I discover the local rank with mpi4py and then use that local rank to target the device I want specifically, rather than by hand. If you need a hint along those lines, the sketches below show how to get the local rank, how to turn that rank into a JAX device, and how to use that device in tensor creation calls; for the last step in my own code, see https://github.com/Nuclear-Physics-with-Machine-Learning/JAX_QMC_Public/blob/main/bin/sr.py#L179-L180
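First, the launcher-environment-variable route. This is a minimal sketch, not code from the repository above; it assumes Open MPI, which sets `OMPI_COMM_WORLD_LOCAL_RANK` for each process:

```python
import os

# Mask the GPUs before JAX initializes: each process keeps only the GPU
# whose index matches its local rank on the node.
local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0")
os.environ["CUDA_VISIBLE_DEVICES"] = local_rank

import jax  # imported after the masking, so this process sees one GPU

print(jax.devices())
```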
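Getting the local rank with mpi4py (again a sketch of the idea, not my exact code): split the world communicator by shared-memory domain, so the rank within the resulting communicator is the rank on the node:

```python
from mpi4py import MPI

def get_local_rank():
    # Processes on the same node share a memory domain, so splitting by
    # MPI.COMM_TYPE_SHARED yields one communicator per node; the rank
    # inside it is this process's index among the processes on its node.
    local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    return local_comm.Get_rank()
```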
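Turning that rank into a JAX device (the function name here is hypothetical):

```python
import jax

def local_rank_to_device(local_rank):
    # Without any masking, every process sees all GPUs on the node, so
    # indexing the visible-device list by local rank gives each process
    # a distinct device.
    return jax.local_devices()[local_rank]
```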
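And finally, using the device in tensor creation calls, building on the two helpers above:

```python
import jax
import jax.numpy as jnp

device = local_rank_to_device(get_local_rank())

# Pin an individual array to the chosen device...
x = jax.device_put(jnp.zeros((4, 4)), device)

# ...or make that device the default for everything created in a block.
with jax.default_device(device):
    y = jnp.ones((4, 4))
```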
-
Hey,
I am working with some code that uses MPI + JAX and I'm having trouble setting the visible devices for JAX. I know that you can do multi-GPU setups with JAX without MPI, but the codebase I'm using forces me to use MPI.
Setup:
Python 3.9.6
CUDA=11.4
NetKet==3.8
jax==0.4.9
jaxlib==0.4.7+cuda11.cudnn82
My MPI configuration is CUDA-enabled, although I don't think this is an MPI problem.
For this MWE, I am working with 2 GPUs on a single node. There is an old issue that discusses how to place devices manually (#2965).
test.py
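(The original script is not shown above; a minimal sketch consistent with the description would just report which devices each rank sees:)

```python
from mpi4py import MPI
import jax

# Each MPI process prints its rank and the JAX devices it can see.
rank = MPI.COMM_WORLD.Get_rank()
print(f"rank {rank}: {jax.devices()}")
```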
Running `mpirun -np 2 python test.py` then shows that both processes see both GPUs. While this is fine, I cannot set the visible devices manually to force process 0 to use GPU:0 by default and process 1 to use GPU:1.
Note that placing things manually works as expected (see the sketch below), but the larger codebase that I'm using expects the placement to be done automatically (so each process has to see only one device).
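A sketch of what manual placement can look like (hypothetical, along the lines of #2965):

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp

# Explicitly pin this process's arrays to one of the two visible GPUs.
rank = MPI.COMM_WORLD.Get_rank()
x = jax.device_put(jnp.zeros((4, 4)), jax.devices()[rank])
```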
Could someone help me understand what is going on here?
Best,
Roeland