Skip to content

Commit b3a2f62

Browse files
committed
Add support for torchrun to DGXC Executor
DGX Cloud uses the PyTorch Training Operator from KubeFlow under the hood to launch jobs. This operator handles many of the variables necessary for running distributed PyTorch jobs with torchrun, so only a subset of settings is required to launch the job; the original default settings would conflict with the auto-configured setup from DGX Cloud. Signed-off-by: Robert Clark <[email protected]>
1 parent 61bb965 commit b3a2f62

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

src/nemo_run/core/execution/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
from nemo_run.core.execution.skypilot import SkypilotExecutor
1818
from nemo_run.core.execution.slurm import SlurmExecutor
1919

20-
__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor"]
20+
__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"]

src/nemo_run/core/execution/dgxcloud.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def create_distributed_job(
138138
return response
139139

140140
def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
141-
name = name.replace("_", "-") # to meet K8s requirements
141+
name = name.replace("_", "-").replace(".", "-") # to meet K8s requirements
142142
token = self.get_auth_token()
143143
if not token:
144144
raise RuntimeError("Failed to get auth token")
@@ -156,6 +156,12 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
156156
status = r_json["actualPhase"]
157157
return job_id, status
158158

159+
def nnodes(self) -> int:
160+
return self.nodes
161+
162+
def nproc_per_node(self) -> int:
163+
return self.gpus_per_node
164+
159165
def status(self, job_id: str) -> Optional[DGXCloudState]:
160166
url = f"{self.base_url}/workloads/distributed/{job_id}"
161167
token = self.get_auth_token()

src/nemo_run/run/torchx_backend/components/torchrun.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def torchrun(
5959
rdzv_backend: str = "c10d",
6060
mounts: Optional[list[str]] = None,
6161
debug: bool = False,
62+
dgxc: bool = False,
6263
) -> specs.AppDef:
6364
"""
6465
Distributed data parallel style application (one role, multi-replica).
@@ -92,6 +93,7 @@ def torchrun(
9293
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
9394
See scheduler documentation for more info.
9495
debug: whether to run with preset debug flags enabled
96+
dgxc: whether to use a subset of settings for DGX Cloud
9597
"""
9698
if (script is None) == (m is None):
9799
raise ValueError("exactly one of --script and -m must be specified")
@@ -130,24 +132,27 @@ def torchrun(
130132
if debug:
131133
env.update(_TORCH_DEBUG_FLAGS)
132134

133-
cmd = [
134-
"--rdzv-backend",
135-
rdzv_backend,
136-
"--rdzv-endpoint",
137-
rdzv_endpoint,
138-
"--rdzv-id",
139-
f"{random.randint(1, 10000)}",
140-
"--nnodes",
141-
num_nodes,
142-
"--nproc-per-node",
143-
nproc_per_node,
144-
"--node-rank",
145-
node_rank,
146-
"--tee",
147-
"3",
148-
# "--role",
149-
# "",
150-
]
135+
if dgxc:
136+
cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
137+
else:
138+
cmd = [
139+
"--rdzv-backend",
140+
rdzv_backend,
141+
"--rdzv-endpoint",
142+
rdzv_endpoint,
143+
"--rdzv-id",
144+
f"{random.randint(1, 10000)}",
145+
"--nnodes",
146+
num_nodes,
147+
"--nproc-per-node",
148+
nproc_per_node,
149+
"--node-rank",
150+
node_rank,
151+
"--tee",
152+
"3",
153+
# "--role",
154+
# "",
155+
]
151156
if script is not None:
152157
if no_python:
153158
cmd += ["--no-python"]

src/nemo_run/run/torchx_backend/packaging.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from nemo_run.config import SCRIPTS_DIR, Partial, Script
2525
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
26+
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
2627
from nemo_run.core.execution.local import LocalExecutor
2728
from nemo_run.core.serialization.yaml import YamlSerializer
2829
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
@@ -139,6 +140,7 @@ def package(
139140
mounts=mounts,
140141
debug=executor.packager.debug,
141142
max_retries=executor.retries,
143+
dgxc=isinstance(executor, DGXCloudExecutor),
142144
)
143145
elif launcher and isinstance(launcher, FaultTolerance):
144146
app_def = ft_launcher.ft_launcher(

0 commit comments

Comments (0)