Skip to content

Commit b33b7d3

Browse files
authored
Enable module execution with '-m' (#63)
1 parent 4a5c36e commit b33b7d3

File tree

8 files changed

+149
-32
lines changed

8 files changed

+149
-32
lines changed

hpc_launcher/cli/torchrun_hpc.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535
def main():
3636
parser = argparse.ArgumentParser(
37-
description="A wrapper script that launches and runs distributed PyTorch on HPC systems."
37+
description=
38+
"A wrapper script that launches and runs distributed PyTorch on HPC systems."
3839
)
3940
common_args.setup_arguments(parser)
4041
parser.add_argument(
@@ -57,11 +58,20 @@ def main():
5758
"--unswap-rocr-hip-vis-dev",
5859
action="store_true",
5960
default=False,
60-
help="Undo moving ROCR_VISIBLE_DEVICES into the HIP_VISIBLE_DEVICES env variable. "
61+
help=
62+
"Undo moving ROCR_VISIBLE_DEVICES into the HIP_VISIBLE_DEVICES env variable. "
6163
"In PyTorch codes HIP_VISIBLE_DEVICES is most similar to CUDA_VISIBLE_DEVICES. "
6264
"Ensureing that HIP vs ROCR can improve behavior of HF Accelerate and TorchTitan.",
6365
)
6466

67+
parser.add_argument(
68+
"-m",
69+
"--module",
70+
action="store_true",
71+
default=False,
72+
help="If specified, the command will be interpreted as "
73+
"a Python module (similar to `python -m module ...`).")
74+
6575
# Grab the rest of the command line to launch
6676
# torchrun-hpc does not support running with a pre-generated batch script file
6777
parser.add_argument("command", help="Command to be executed")
@@ -86,7 +96,9 @@ def main():
8696
if args.job_comm_protocol:
8797
optimize_comm_protocol = args.job_comm_protocol
8898
if optimize_comm_protocol.upper() == "MPI":
89-
logger.warning(f"Using MPI as the primary communication protocol for PyTorch requires additional support")
99+
logger.warning(
100+
f"Using MPI as the primary communication protocol for PyTorch requires additional support"
101+
)
90102
else:
91103
system.job_comm_protocol = "*CCL"
92104
# Pick batch scheduler
@@ -122,21 +134,22 @@ def main():
122134
)
123135
exit(1)
124136

125-
if args.bg and args.launch_dir is None: # or args.batch_script
137+
if args.bg and args.launch_dir is None: # or args.batch_script
126138
# If running a batch job with no launch directory argument,
127139
# run in the generated timestamped directory
128140
args.launch_dir = ""
129141
if args.launch_dir is None and not args.bg:
130142
args.launch_dir = ""
131-
logger.info(f"torchrun-hpc needs to run jobs from a launch directory -- automatically setting the -l (--launch-dir) CLI argument")
143+
logger.info(
144+
f"torchrun-hpc needs to run jobs from a launch directory -- automatically setting the -l (--launch-dir) CLI argument"
145+
)
132146

133147
_, folder_name = scheduler.create_launch_folder_name(
134-
args.command, "torchrun_hpc", args.launch_dir
135-
)
148+
args.command, "torchrun_hpc", args.launch_dir)
136149

137-
script_file = scheduler.create_launch_folder(
138-
folder_name, not args.bg, args.output_script, args.dry_run
139-
)
150+
script_file = scheduler.create_launch_folder(folder_name, not args.bg,
151+
args.output_script,
152+
args.dry_run)
140153

141154
trampoline_file = "torchrun_hpc_trampoline.py"
142155

@@ -152,8 +165,11 @@ def main():
152165
launch_args = [
153166
"-u",
154167
f"{os.path.abspath(folder_name)}/{trampoline_file}",
155-
os.path.abspath(args.command),
156168
]
169+
if args.module:
170+
launch_args += ["-m", args.command]
171+
else:
172+
launch_args.append(os.path.abspath(args.command))
157173
launch_args += args.args
158174

159175
logger.info(f"Running job in directory: {folder_name}")

hpc_launcher/torch/torchrun_hpc_trampoline.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,41 @@
2626
def main():
2727
# Strip off the name of this script and pass the rest to runpy
2828
args = sys.argv[1:]
29+
if args[0] == "-m":
30+
is_module = True
31+
args = args[1:]
32+
else:
33+
is_module = False
2934

3035
scheduler_type = os.getenv("TORCHRUN_HPC_SCHEDULER")
3136
scheduler = get_schedulers()[scheduler_type]
32-
(world_size, rank, local_world_size, local_rank) = (
33-
scheduler.get_parallel_configuration()
34-
)
37+
(world_size, rank, local_world_size,
38+
local_rank) = (scheduler.get_parallel_configuration())
3539

3640
# Check on the backend and report if the memory size was set
3741
backend = None
3842
device = None
3943
if torch.cuda.is_available():
4044
backend = "nccl"
4145
device = "cuda"
42-
fraction_max_gpu_mem = float(os.getenv("HPC_LAUNCHER_MAX_GPU_MEM", 1.0))
46+
fraction_max_gpu_mem = float(os.getenv("HPC_LAUNCHER_MAX_GPU_MEM",
47+
1.0))
4348
if fraction_max_gpu_mem != 1.0 and rank == 0:
4449
print(
4550
f"[Rank {rank} of {world_size}] TORCHRUN-HPC set the max GPU memory fraction to {fraction_max_gpu_mem}"
4651
)
4752
else:
4853
backend = "gloo"
49-
device="cpu"
54+
device = "cpu"
5055

5156
# Standard operating mode assumes that there is one rank per GPU
5257
# Check to see how many GPUS are actually available to this rank
5358
avail_gpus = 0
5459
gpus = []
55-
for e in ["CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"]:
60+
for e in [
61+
"CUDA_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES",
62+
"HIP_VISIBLE_DEVICES"
63+
]:
5664
if os.getenv(e):
5765
gpus = os.getenv(e)
5866
break
@@ -97,21 +105,18 @@ def main():
97105
f"[Rank {rank} of {world_size}]: Initializing distributed PyTorch using protocol: {rdv_protocol}"
98106
)
99107
# TODO(later): Fix how we handle CUDA visible devices and MPI bind
100-
dist.init_process_group(
101-
backend, init_method=rdv_protocol, world_size=world_size, rank=rank, device_id=torch.device(device, local_device_id)
102-
)
108+
dist.init_process_group(backend,
109+
init_method=rdv_protocol,
110+
world_size=world_size,
111+
rank=rank,
112+
device_id=torch.device(
113+
device, local_device_id))
103114

104115
if rdv_protocol == "mpi://" and rank == 0:
105-
print(
106-
"[Rank {} of {}]: MPI Version: {}".format(
107-
rank, world_size, MPI.Get_version()
108-
)
109-
)
110-
print(
111-
"[Rank {} of {}]: MPI Implementation: {}".format(
112-
rank, world_size, MPI.Get_library_version()
113-
)
114-
)
116+
print("[Rank {} of {}]: MPI Version: {}".format(
117+
rank, world_size, MPI.Get_version()))
118+
print("[Rank {} of {}]: MPI Implementation: {}".format(
119+
rank, world_size, MPI.Get_library_version()))
115120

116121
# If the world size is only 1, torch distributed doesn't have to be initialized
117122
# however, the called application may try to setup torch distributed -- provide env variables
@@ -130,9 +135,13 @@ def main():
130135
os.environ["MASTER_PORT"] = "23456"
131136

132137
# Note that run_path will prepend the args[0] back onto the sys.argv so it needs to be stripped off first
133-
sys.argv = sys.argv[1:]
138+
sys.argv = sys.argv[1:] if not is_module else sys.argv[2:]
139+
134140
# Run underlying script
135-
runpy.run_path(args[0], run_name="__main__")
141+
if is_module:
142+
runpy.run_module(args[0], run_name="__main__", alter_sys=True)
143+
else:
144+
runpy.run_path(args[0], run_name="__main__")
136145

137146
if dist.is_initialized():
138147
# Deal with destroying the process group here

tests/e2e/relimport/__init__.py

Whitespace-only changes.

tests/e2e/relimport/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
"""Entry point so the ``relimport`` test package can run via ``python -m relimport``."""
import sys

from .subfolder.b import g

if __name__ == "__main__":
    # Single CLI argument: a float fed through g(); the result is printed so
    # the end-to-end test can capture it on stdout.
    print(g(float(sys.argv[1])))

tests/e2e/relimport/a.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@


def f(value: int) -> int:
    """Return ``value`` incremented by one."""
    incremented = 1 + value
    return incremented

tests/e2e/relimport/subfolder/__init__.py

Whitespace-only changes.

tests/e2e/relimport/subfolder/b.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
from ..a import f


def g(a: float) -> int:
    """Shift ``a`` by 2.5, truncate to int, and apply :func:`f` from the parent package."""
    shifted = int(a + 2.5)
    return f(shifted)

tests/test_relative_import.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
#
# SPDX-License-Identifier: (Apache-2.0)
import pytest

import subprocess
import shutil
import os
import re
import sys

from hpc_launcher.systems import autodetect
from hpc_launcher.systems.lc.sierra_family import Sierra
from hpc_launcher.schedulers import get_schedulers


def test_torchrun_hpc_relimport():
    """End-to-end check that ``torchrun-hpc -m`` runs a package with relative imports.

    Launches ``python -m hpc_launcher.cli.torchrun_hpc ... -m relimport 4.75``
    from the ``tests/e2e`` directory and expects the package to print ``8``
    (g(4.75) == f(int(4.75 + 2.5)) == f(7) == 8).

    Skips when no usable distributed launcher is available, when the current
    allocation has insufficient resources, or when torch is not installed.
    """
    scheduler_type = "slurm"

    # Launcher-detection logic: skip unless the binaries (and, for flux, the
    # local broker socket) that the chosen scheduler needs are present.
    no_launcher = (
        (scheduler_type == "slurm"
         and (not shutil.which("srun")
              or (shutil.which("srun") and shutil.which("jsrun"))))
        or (scheduler_type == "flux"
            and (not shutil.which("flux")
                 or not os.path.exists("/run/flux/local")))
        or (scheduler_type == "lsf" and not shutil.which("jsrun"))
    )
    if no_launcher:
        pytest.skip("No distributed launcher found")

    scheduler = get_schedulers()[scheduler_type]
    num_nodes_in_allocation = scheduler.num_nodes_in_allocation()
    if num_nodes_in_allocation is not None and num_nodes_in_allocation == 1:
        pytest.skip(
            "Executed inside of an allocation with insufficient resources")

    try:
        import torch  # noqa: F401 -- only probing for availability
    except (ImportError, ModuleNotFoundError):
        pytest.skip("torch not found")

    cmd = [
        sys.executable,
        "-m",
        "hpc_launcher.cli.torchrun_hpc",
        "-l",
        "-v",
        "-N",
        "1",
        "-n",
        "1",
        "-m",
        "relimport",
        "4.75",
    ]
    # Run from tests/e2e so the "relimport" package is importable as a module.
    cwd = os.path.join(os.path.dirname(__file__), "e2e")
    proc = subprocess.run(cmd,
                          universal_newlines=True,
                          capture_output=True,
                          cwd=cwd)

    assert proc.returncode == 0
    assert proc.stdout.strip() == "8"


if __name__ == "__main__":
    test_torchrun_hpc_relimport()

0 commit comments

Comments
 (0)