Skip to content

Commit 33c0e0b

Browse files
authored
Support torchrun multi node on local executor (#143)
* Support torchrun multi node on local executor Signed-off-by: Hemil Desai <[email protected]> * fix Signed-off-by: Hemil Desai <[email protected]> * fix Signed-off-by: Hemil Desai <[email protected]> --------- Signed-off-by: Hemil Desai <[email protected]>
1 parent abccf6e commit 33c0e0b

File tree

6 files changed

+33
-13
lines changed

6 files changed

+33
-13
lines changed

nemo_run/core/execution/launcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def transform(self, cmd: list[str]) -> Optional[Script]: ...
4747
class Torchrun(Launcher):
4848
rdzv_backend: str = "c10d"
4949
rdzv_port: int = 29500
50+
rdzv_id: Optional[int] = None
5051

5152

5253
@dataclass(kw_only=True)
@@ -56,6 +57,7 @@ class FaultTolerance(Launcher):
5657
job_results_file: str = ""
5758
rdzv_backend: str = "c10d"
5859
rdzv_port: int = 29500
60+
rdzv_id: Optional[int] = None
5961
workload_check_interval: Optional[float] = None
6062
initial_rank_heartbeat_timeout: Optional[float] = None
6163
rank_heartbeat_timeout: Optional[float] = None

nemo_run/core/execution/local.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class LocalExecutor(Executor):
3737

3838
#: Used by components like torchrun to deduce the number of tasks to launch.
3939
ntasks_per_node: int = 1
40+
nodes: int = 1
4041

4142
def assign(
4243
self,
@@ -50,7 +51,7 @@ def assign(
5051
self.job_dir = os.path.join(exp_dir, task_dir)
5152

5253
def nnodes(self) -> int:
53-
return 1
54+
return self.nodes
5455

5556
def nproc_per_node(self) -> int:
5657
return self.ntasks_per_node

nemo_run/run/torchx_backend/components/ft_launcher.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def ft_launcher(
4040
max_retries: int = 0,
4141
rdzv_port: int = 49450,
4242
rdzv_backend: str = "c10d",
43+
rdzv_id: Optional[int] = None,
4344
mounts: Optional[list[str]] = None,
4445
debug: bool = False,
4546
workload_check_interval: Optional[float] = None,
@@ -48,6 +49,8 @@ def ft_launcher(
4849
rank_termination_signal: Optional[str] = None,
4950
log_level: Optional[str] = None,
5051
max_restarts: Optional[int] = None,
52+
dgxc: bool = False,
53+
use_env: bool = False,
5154
) -> specs.AppDef:
5255
torchrun_component = torchrun.torchrun(
5356
*script_args,
@@ -63,10 +66,13 @@ def ft_launcher(
6366
j=j,
6467
rdzv_backend=rdzv_backend,
6568
rdzv_port=rdzv_port,
69+
rdzv_id=rdzv_id,
6670
env=env,
6771
mounts=mounts,
6872
debug=debug,
6973
max_retries=max_retries,
74+
dgxc=dgxc,
75+
use_env=use_env,
7076
)
7177

7278
ft_args = []

nemo_run/run/torchx_backend/components/torchrun.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ def torchrun(
5757
max_retries: int = 0,
5858
rdzv_port: int = 49450,
5959
rdzv_backend: str = "c10d",
60+
rdzv_id: Optional[int] = None,
6061
mounts: Optional[list[str]] = None,
6162
debug: bool = False,
6263
dgxc: bool = False,
64+
use_env: bool = False,
6365
) -> specs.AppDef:
6466
"""
6567
Distributed data parallel style application (one role, multi-replica).
@@ -113,17 +115,21 @@ def torchrun(
113115
nproc_per_node = str(nproc_per_node)
114116
node_rank = "0"
115117
else:
116-
# for multi-node, rely on the rank0_env environment variable set by
117-
# the schedulers (see scheduler implementation for the actual env var this maps to)
118-
# some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
119-
# so default to "localhost" if the env var is not set or is empty
120-
# rdzv_endpoint bash resolves to something to the effect of
121-
# ${TORCHX_RANK0_HOST:=localhost}:29500
122-
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
123-
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
124-
num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}")
118+
if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"):
119+
master_addr = os.environ["MASTER_ADDR"]
120+
master_port = os.environ["MASTER_PORT"]
121+
rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port)
122+
random.seed(rdzv_id)
123+
else:
124+
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
125+
126+
num_nodes = nnodes_rep
125127
nproc_per_node = str(nproc_per_node)
126-
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")
128+
129+
if use_env and os.getenv("NODE_RANK"):
130+
node_rank = os.environ["NODE_RANK"]
131+
else:
132+
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")
127133

128134
if env is None:
129135
env = {}
@@ -141,7 +147,7 @@ def torchrun(
141147
"--rdzv-endpoint",
142148
rdzv_endpoint,
143149
"--rdzv-id",
144-
f"{random.randint(1, 10000)}",
150+
f"{rdzv_id or random.randint(1, 10000)}",
145151
"--nnodes",
146152
num_nodes,
147153
"--nproc-per-node",

nemo_run/run/torchx_backend/packaging.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
145145
transformed_script, serialize_configs=False
146146
)
147147

148+
use_env = isinstance(executor, LocalExecutor)
148149
if launcher and isinstance(launcher, Torchrun):
149150
app_def = torchrun.torchrun(
150151
*args,
@@ -160,11 +161,13 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
160161
j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
161162
rdzv_backend=launcher.rdzv_backend,
162163
rdzv_port=launcher.rdzv_port,
164+
rdzv_id=launcher.rdzv_id,
163165
env=env,
164166
mounts=mounts,
165167
debug=executor.packager.debug,
166168
max_retries=executor.retries,
167169
dgxc=isinstance(executor, DGXCloudExecutor),
170+
use_env=use_env,
168171
)
169172
elif launcher and isinstance(launcher, FaultTolerance):
170173
app_def = ft_launcher.ft_launcher(
@@ -181,6 +184,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
181184
j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
182185
rdzv_backend=launcher.rdzv_backend,
183186
rdzv_port=launcher.rdzv_port,
187+
rdzv_id=launcher.rdzv_id,
184188
env=env,
185189
mounts=mounts,
186190
debug=executor.packager.debug,
@@ -191,6 +195,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
191195
log_level=launcher.log_level,
192196
max_retries=executor.retries,
193197
max_restarts=launcher.max_restarts,
198+
use_env=use_env,
194199
)
195200
else:
196201
app_def = specs.AppDef(

test/run/torchx_backend/test_packaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_package_torchrun(mock_executor):
207207
"--rdzv-id",
208208
"1",
209209
"--nnodes",
210-
"$$${num_nodes_var}",
210+
"2",
211211
"--nproc-per-node",
212212
"1",
213213
"--node-rank",

0 commit comments

Comments
 (0)