Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo_run/core/execution/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def transform(self, cmd: list[str]) -> Optional[Script]: ...
class Torchrun(Launcher):
rdzv_backend: str = "c10d"
rdzv_port: int = 29500
rdzv_id: Optional[int] = None


@dataclass(kw_only=True)
Expand All @@ -56,6 +57,7 @@ class FaultTolerance(Launcher):
job_results_file: str = ""
rdzv_backend: str = "c10d"
rdzv_port: int = 29500
rdzv_id: Optional[int] = None
workload_check_interval: Optional[float] = None
initial_rank_heartbeat_timeout: Optional[float] = None
rank_heartbeat_timeout: Optional[float] = None
Expand Down
3 changes: 2 additions & 1 deletion nemo_run/core/execution/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class LocalExecutor(Executor):

#: Used by components like torchrun to deduce the number of tasks to launch.
ntasks_per_node: int = 1
nodes: int = 1

def assign(
self,
Expand All @@ -50,7 +51,7 @@ def assign(
self.job_dir = os.path.join(exp_dir, task_dir)

def nnodes(self) -> int:
return 1
return self.nodes

def nproc_per_node(self) -> int:
return self.ntasks_per_node
6 changes: 6 additions & 0 deletions nemo_run/run/torchx_backend/components/ft_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def ft_launcher(
max_retries: int = 0,
rdzv_port: int = 49450,
rdzv_backend: str = "c10d",
rdzv_id: Optional[int] = None,
mounts: Optional[list[str]] = None,
debug: bool = False,
workload_check_interval: Optional[float] = None,
Expand All @@ -48,6 +49,8 @@ def ft_launcher(
rank_termination_signal: Optional[str] = None,
log_level: Optional[str] = None,
max_restarts: Optional[int] = None,
dgxc: bool = False,
use_env: bool = False,
) -> specs.AppDef:
torchrun_component = torchrun.torchrun(
*script_args,
Expand All @@ -63,10 +66,13 @@ def ft_launcher(
j=j,
rdzv_backend=rdzv_backend,
rdzv_port=rdzv_port,
rdzv_id=rdzv_id,
env=env,
mounts=mounts,
debug=debug,
max_retries=max_retries,
dgxc=dgxc,
use_env=use_env,
)

ft_args = []
Expand Down
28 changes: 17 additions & 11 deletions nemo_run/run/torchx_backend/components/torchrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ def torchrun(
max_retries: int = 0,
rdzv_port: int = 49450,
rdzv_backend: str = "c10d",
rdzv_id: Optional[int] = None,
mounts: Optional[list[str]] = None,
debug: bool = False,
dgxc: bool = False,
use_env: bool = False,
) -> specs.AppDef:
"""
Distributed data parallel style application (one role, multi-replica).
Expand Down Expand Up @@ -113,17 +115,21 @@ def torchrun(
nproc_per_node = str(nproc_per_node)
node_rank = "0"
else:
# for multi-node, rely on the rank0_env environment variable set by
# the schedulers (see scheduler implementation for the actual env var this maps to)
# some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
# so default to "localhost" if the env var is not set or is empty
# rdzv_endpoint bash resolves to something to the effect of
# ${TORCHX_RANK0_HOST:=localhost}:29500
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}")
if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"):
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port)
random.seed(rdzv_id)
else:
rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")

num_nodes = nnodes_rep
nproc_per_node = str(nproc_per_node)
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")

if use_env and os.getenv("NODE_RANK"):
node_rank = os.environ["NODE_RANK"]
else:
node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")

if env is None:
env = {}
Expand All @@ -141,7 +147,7 @@ def torchrun(
"--rdzv-endpoint",
rdzv_endpoint,
"--rdzv-id",
f"{random.randint(1, 10000)}",
f"{rdzv_id or random.randint(1, 10000)}",
"--nnodes",
num_nodes,
"--nproc-per-node",
Expand Down
5 changes: 5 additions & 0 deletions nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
transformed_script, serialize_configs=False
)

use_env = isinstance(executor, LocalExecutor)
if launcher and isinstance(launcher, Torchrun):
app_def = torchrun.torchrun(
*args,
Expand All @@ -160,11 +161,13 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
rdzv_backend=launcher.rdzv_backend,
rdzv_port=launcher.rdzv_port,
rdzv_id=launcher.rdzv_id,
env=env,
mounts=mounts,
debug=executor.packager.debug,
max_retries=executor.retries,
dgxc=isinstance(executor, DGXCloudExecutor),
use_env=use_env,
)
elif launcher and isinstance(launcher, FaultTolerance):
app_def = ft_launcher.ft_launcher(
Expand All @@ -181,6 +184,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
rdzv_backend=launcher.rdzv_backend,
rdzv_port=launcher.rdzv_port,
rdzv_id=launcher.rdzv_id,
env=env,
mounts=mounts,
debug=executor.packager.debug,
Expand All @@ -191,6 +195,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
log_level=launcher.log_level,
max_retries=executor.retries,
max_restarts=launcher.max_restarts,
use_env=use_env,
)
else:
app_def = specs.AppDef(
Expand Down
2 changes: 1 addition & 1 deletion test/run/torchx_backend/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_package_torchrun(mock_executor):
"--rdzv-id",
"1",
"--nnodes",
"$$${num_nodes_var}",
"2",
"--nproc-per-node",
"1",
"--node-rank",
Expand Down
Loading