Skip to content

Commit 570f577

Browse files
authored
Merge pull request #145 from roclark/roclark/dgxc-executor-torchrun
Add support for torchrun to DGXC Executor
2 parents 61bb965 + b3a2f62 commit 570f577

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

src/nemo_run/core/execution/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
from nemo_run.core.execution.skypilot import SkypilotExecutor
1818
from nemo_run.core.execution.slurm import SlurmExecutor
1919

20-
__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor"]
20+
__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"]

src/nemo_run/core/execution/dgxcloud.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def create_distributed_job(
138138
return response
139139

140140
def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
141-
name = name.replace("_", "-") # to meet K8s requirements
141+
name = name.replace("_", "-").replace(".", "-") # to meet K8s requirements
142142
token = self.get_auth_token()
143143
if not token:
144144
raise RuntimeError("Failed to get auth token")
@@ -156,6 +156,12 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
156156
status = r_json["actualPhase"]
157157
return job_id, status
158158

159+
def nnodes(self) -> int:
160+
return self.nodes
161+
162+
def nproc_per_node(self) -> int:
163+
return self.gpus_per_node
164+
159165
def status(self, job_id: str) -> Optional[DGXCloudState]:
160166
url = f"{self.base_url}/workloads/distributed/{job_id}"
161167
token = self.get_auth_token()

src/nemo_run/run/torchx_backend/components/torchrun.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def torchrun(
5959
rdzv_backend: str = "c10d",
6060
mounts: Optional[list[str]] = None,
6161
debug: bool = False,
62+
dgxc: bool = False,
6263
) -> specs.AppDef:
6364
"""
6465
Distributed data parallel style application (one role, multi-replica).
@@ -92,6 +93,7 @@ def torchrun(
9293
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
9394
See scheduler documentation for more info.
9495
debug: whether to run with preset debug flags enabled
96+
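dgxc: whether to build the torchrun command for DGX Cloud, passing only ``--nnodes`` and ``--nproc-per-node`` and omitting the rendezvous, ``--node-rank`` and ``--tee`` options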
9597
"""
9698
if (script is None) == (m is None):
9799
raise ValueError("exactly one of --script and -m must be specified")
@@ -130,24 +132,27 @@ def torchrun(
130132
if debug:
131133
env.update(_TORCH_DEBUG_FLAGS)
132134

133-
cmd = [
134-
"--rdzv-backend",
135-
rdzv_backend,
136-
"--rdzv-endpoint",
137-
rdzv_endpoint,
138-
"--rdzv-id",
139-
f"{random.randint(1, 10000)}",
140-
"--nnodes",
141-
num_nodes,
142-
"--nproc-per-node",
143-
nproc_per_node,
144-
"--node-rank",
145-
node_rank,
146-
"--tee",
147-
"3",
148-
# "--role",
149-
# "",
150-
]
135+
if dgxc:
136+
cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
137+
else:
138+
cmd = [
139+
"--rdzv-backend",
140+
rdzv_backend,
141+
"--rdzv-endpoint",
142+
rdzv_endpoint,
143+
"--rdzv-id",
144+
f"{random.randint(1, 10000)}",
145+
"--nnodes",
146+
num_nodes,
147+
"--nproc-per-node",
148+
nproc_per_node,
149+
"--node-rank",
150+
node_rank,
151+
"--tee",
152+
"3",
153+
# "--role",
154+
# "",
155+
]
151156
if script is not None:
152157
if no_python:
153158
cmd += ["--no-python"]

src/nemo_run/run/torchx_backend/packaging.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from nemo_run.config import SCRIPTS_DIR, Partial, Script
2525
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
26+
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
2627
from nemo_run.core.execution.local import LocalExecutor
2728
from nemo_run.core.serialization.yaml import YamlSerializer
2829
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
@@ -139,6 +140,7 @@ def package(
139140
mounts=mounts,
140141
debug=executor.packager.debug,
141142
max_retries=executor.retries,
143+
dgxc=isinstance(executor, DGXCloudExecutor),
142144
)
143145
elif launcher and isinstance(launcher, FaultTolerance):
144146
app_def = ft_launcher.ft_launcher(

0 commit comments

Comments
 (0)