Skip to content

Commit 4c14b52

Browse files
authored
Added the device_id initialization to the init_process_group call in
the torchrun-hpc trampoline. (#56)
1 parent b44575d commit 4c14b52

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

hpc_launcher/torch/torchrun_hpc_trampoline.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,36 @@ def main():
3535

3636
# Check on the backend and report if the memory size was set
3737
backend = None
38+
device = None
3839
if torch.cuda.is_available():
3940
backend = "nccl"
41+
device = "cuda"
4042
fraction_max_gpu_mem = float(os.getenv("HPC_LAUNCHER_MAX_GPU_MEM", 1.0))
4143
if fraction_max_gpu_mem != 1.0 and rank == 0:
4244
print(
4345
f"[Rank {rank} of {world_size}] TORCHRUN-HPC set the max GPU memory fraction to {fraction_max_gpu_mem}"
4446
)
4547
else:
4648
backend = "gloo"
49+
device="cpu"
50+
51+
# Standard operating mode assumes that there is one rank per GPU
52+
# Check to see how many GPUS are actually available to this rank
53+
avail_gpus = 0
54+
gpus = []
55+
for e in ["CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"]:
56+
if os.getenv(e):
57+
gpus = os.getenv(e)
58+
break
59+
if gpus:
60+
avail_gpus = gpus.split(",")
61+
62+
# Round-robin assign the visible GPUs
63+
if avail_gpus:
64+
local_device_id = local_rank % len(avail_gpus)
65+
else:
66+
local_device_id = local_rank
67+
os.environ["LOCAL_RANK"] = f"{local_device_id}"
4768

4869
torch_dist_initialized = dist.is_initialized()
4970
rdv_protocol = os.getenv("TORCHRUN_HPC_RDV_PROTOCOL")
@@ -77,7 +98,7 @@ def main():
7798
)
7899
# TODO(later): Fix how we handle CUDA visible devices and MPI bind
79100
dist.init_process_group(
80-
backend, init_method=rdv_protocol, world_size=world_size, rank=rank
101+
backend, init_method=rdv_protocol, world_size=world_size, rank=rank, device_id=torch.device(device, local_device_id)
81102
)
82103

83104
if rdv_protocol == "mpi://" and rank == 0:
@@ -108,24 +129,6 @@ def main():
108129
# If the mpi rendezvous protocol is set, this should be necessary but some packages still look for it
109130
os.environ["MASTER_ADDR"] = "23456"
110131

111-
# Standard operating mode assumes that there is one rank per GPU
112-
# Check to see how many GPUS are actually available to this rank
113-
avail_gpus = 0
114-
gpus = []
115-
for e in ["CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"]:
116-
if os.getenv(e):
117-
gpus = os.getenv(e)
118-
break
119-
if gpus:
120-
avail_gpus = gpus.split(",")
121-
122-
# Round-robin assign the visible GPUs
123-
if avail_gpus:
124-
local_gpu_id = local_rank % len(avail_gpus)
125-
else:
126-
local_gpu_id = local_rank
127-
os.environ["LOCAL_RANK"] = f"{local_gpu_id}"
128-
129132
# Note that run_path will prepend the args[0] back onto the sys.argv so it needs to be stripped off first
130133
sys.argv = sys.argv[1:]
131134
# Run underlying script

0 commit comments

Comments
 (0)