Skip to content

Commit 3074823

Browse files
authored
Bugfix torchrun hpc (#23)
* Fix a bug in how sys.argv is being constructed and passed to runpy.run_path. Reduced logging in torchrun_hpc_trampoline. * Added a __init__.py file for lc systems.
1 parent 06150fc commit 3074823

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

hpc_launcher/systems/lc/__init__.py

Whitespace-only changes.

hpc_launcher/torch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,5 +29,4 @@
2929
if torch.cuda.is_available():
3030
fraction_max_gpu_mem = float(os.getenv("TORCHRUN_HPC_MAX_GPU_MEM"))
3131
if fraction_max_gpu_mem != 1.0:
32-
print(f"Setting the max GPU memory fraction to {fraction_max_gpu_mem}")
3332
torch.cuda.set_per_process_memory_fraction(fraction_max_gpu_mem)

hpc_launcher/torch/torchrun_hpc_trampoline.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,12 @@ def main():
3333
scheduler.get_parallel_configuration()
3434
)
3535

36+
# Report if the memory size was set
37+
if torch.cuda.is_available():
38+
fraction_max_gpu_mem = float(os.getenv("TORCHRUN_HPC_MAX_GPU_MEM"))
39+
if fraction_max_gpu_mem != 1.0 and rank == 0:
40+
print(f"[Rank {rank} of {world_size}] TORCHRUN-HPC set the max GPU memory fraction to {fraction_max_gpu_mem}")
41+
3642
torch_dist_initialized = dist.is_initialized()
3743
if world_size > 1:
3844
rdv_protocol = os.getenv("TORCHRUN_HPC_RDV_PROTOCOL")
@@ -52,16 +58,19 @@ def main():
5258
)
5359

5460
if not torch_dist_initialized:
55-
print(f"Initializing distributed PyTorch using protocol: {rdv_protocol}")
61+
if rank == 0:
62+
print(f"[Rank {rank} of {world_size}]: Initializing distributed PyTorch using protocol: {rdv_protocol}")
5663
# TODO(later): Fix how we handle CUDA visible devices and MPI bind
5764
dist.init_process_group(
5865
"nccl", init_method=rdv_protocol, world_size=world_size, rank=rank
5966
)
6067

61-
if rdv_protocol == "mpi://":
62-
print("MPI Version: {}".format(MPI.Get_version()))
63-
print("MPI Implementation: {}".format(MPI.Get_library_version()))
68+
if rdv_protocol == "mpi://" and rank == 0:
69+
print("[Rank {} of {}]: MPI Version: {}".format(rank, world_size, MPI.Get_version()))
70+
print("[Rank {} of {}]: MPI Implementation: {}".format(rank, world_size, MPI.Get_library_version()))
6471

72+
# Note that run_path will prepend the args[0] back onto the sys.argv so it needs to be stripped off first
73+
sys.argv = sys.argv[1:]
6574
# Run underlying script
6675
runpy.run_path(args[0], run_name="__main__")
6776

0 commit comments

Comments (0)