
Commit 463d6e1

hemildesai, roclark, and pablo-garay authored
Add DGXCloud executor based on Run.ai REST API (#141)
* Add DGXCloudExecutor

  Signed-off-by: Hemil Desai <[email protected]>

* Add dgx cloud scheduler

  Signed-off-by: Hemil Desai <[email protected]>

* Fix formatting

  Signed-off-by: Hemil Desai <[email protected]>

* Add support for torchrun to DGXC Executor

  DGX Cloud uses the PyTorch Training Operator from KubeFlow under the hood to launch jobs. The operator handles many of the variables needed for running distributed PyTorch jobs with torchrun, so only a subset of settings is required to launch the job; the original default settings would conflict with the auto-configured setup from DGX Cloud.

  Signed-off-by: Robert Clark <[email protected]>

* Add missing import for DGXCloudExecutor

  Signed-off-by: Pablo Garay <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: Robert Clark <[email protected]>
Signed-off-by: Pablo Garay <[email protected]>
Co-authored-by: Robert Clark <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
1 parent 5ed6128 commit 463d6e1
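For orientation, a minimal configuration sketch using only the fields this commit adds; the endpoint URL, container image, and any PVC keys beyond "path" are placeholders, not values prescribed by the commit:

import nemo_run as run

# Hypothetical values: point base_url at your Run:ai control plane REST API.
executor = run.DGXCloudExecutor(
    base_url="https://<runai-control-plane>/api/v1",  # placeholder endpoint
    app_id="my-app-id",                               # application credentials from Run:ai
    app_secret="my-app-secret",
    project_name="my-project",
    container_image="nvcr.io/nvidia/nemo:24.07",      # placeholder image
    nodes=2,
    gpus_per_node=8,
    # assign() requires the job directory (NEMORUN_HOME) to live under one of these paths;
    # keys other than "path" are assumptions about the Run:ai PVC spec.
    pvcs=[{"path": "/demo-pvc", "name": "demo-pvc"}],
)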

File tree

9 files changed: +561 -21 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ slurm_tunnel = "nemo_run.run.torchx_backend.schedulers.slurm:create_scheduler"
 skypilot = "nemo_run.run.torchx_backend.schedulers.skypilot:create_scheduler"
 local_persistent = "nemo_run.run.torchx_backend.schedulers.local:create_scheduler"
 docker_persistent = "nemo_run.run.torchx_backend.schedulers.docker:create_scheduler"
+dgx_cloud = "nemo_run.run.torchx_backend.schedulers.dgxcloud:create_scheduler"
 
 [project.optional-dependencies]
 skypilot = [

src/nemo_run/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     Torchrun,
     import_executor,
 )
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -46,6 +47,7 @@
     "ConfigurableMixin",
     "DevSpace",
     "DockerExecutor",
+    "DGXCloudExecutor",
     "dryrun_fn",
     "Executor",
     "import_executor",

src/nemo_run/core/execution/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -16,5 +16,6 @@
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
 from nemo_run.core.execution.slurm import SlurmExecutor
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 
-__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor"]
+__all__ = ["LocalExecutor", "SlurmExecutor", "SkypilotExecutor", "DGXCloudExecutor"]
src/nemo_run/core/execution/dgxcloud.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
+import json
+import logging
+import os
+import subprocess
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Optional, Type
+
+import requests
+from invoke.context import Context
+
+from nemo_run.core.execution.base import (
+    Executor,
+    ExecutorMacros,
+)
+from nemo_run.core.packaging.base import Packager
+from nemo_run.core.packaging.git import GitArchivePackager
+
+logger = logging.getLogger(__name__)
+
+
+class DGXCloudState(Enum):
+    CREATING = "Creating"
+    INITIALIZING = "Initializing"
+    RESUMING = "Resuming"
+    PENDING = "Pending"
+    DELETING = "Deleting"
+    RUNNING = "Running"
+    UPDATING = "Updating"
+    STOPPED = "Stopped"
+    STOPPING = "Stopping"
+    DEGRADED = "Degraded"
+    FAILED = "Failed"
+    COMPLETED = "Completed"
+    TERMINATING = "Terminating"
+    UNKNOWN = "Unknown"
+
+
+@dataclass(kw_only=True)
+class DGXCloudExecutor(Executor):
+    """
+    Dataclass to configure a DGX Executor.
+
+    This executor integrates with a DGX cloud endpoint for launching jobs
+    via a REST API. It acquires an auth token, identifies the project/cluster,
+    and launches jobs with a specified command. It can be adapted to meet user
+    authentication and job-submission requirements on DGX.
+    """
+
+    base_url: str
+    app_id: str
+    app_secret: str
+    project_name: str
+    container_image: str
+    nodes: int = 1
+    gpus_per_node: int = 8
+    pvcs: list[dict[str, Any]] = field(default_factory=list)
+    distributed_framework: str = "PyTorch"
+    custom_spec: dict[str, Any] = field(default_factory=dict)
+
+    def get_auth_token(self) -> Optional[str]:
+        url = f"{self.base_url}/token"
+        payload = {
+            "grantType": "app_token",
+            "appId": self.app_id,
+            "appSecret": self.app_secret,
+        }
+
+        response = requests.post(url, json=payload, headers=self._default_headers())
+        response_text = response.text.strip()
+        auth_token = json.loads(response_text).get("accessToken", None)  # [1]
+        if not auth_token:
+            logger.error("Failed to retrieve auth token; response was: %s", response_text)
+            return None
+
+        logger.debug("Retrieved auth token from %s", url)
+        return auth_token
+
+    def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optional[str]]:
+        url = f"{self.base_url}/org-unit/projects"
+        headers = self._default_headers(token=token)
+        response = requests.get(url, headers=headers)
+        projects = json.loads(response.text.strip()).get("projects", [])
+        project_id, cluster_id = None, None
+        for prj in projects:
+            if not self.project_name or prj["name"] == self.project_name:  # [2]
+                project_id, cluster_id = prj["id"], prj["clusterId"]
+                logger.debug(
+                    "Found project '%s' (%s) on cluster '%s'", prj["name"], project_id, cluster_id
+                )
+                break
+        return project_id, cluster_id
+
+    def create_distributed_job(
+        self, token: str, project_id: str, cluster_id: str, name: str, cmd: list[str]
+    ):
+        """
+        Creates a distributed PyTorch job using the provided project/cluster IDs.
+        """
+        url = f"{self.base_url}/workloads/distributed"
+        headers = self._default_headers(token=token)
+        launch_script = f"""
+ln -s {self.job_dir} /nemo_run
+cd /nemo_run/code
+{" ".join(cmd)}
+"""
+        with open(os.path.join(self.job_dir, "launch_script.sh"), "w+") as f:
+            f.write(launch_script)
+
+        payload = {
+            "name": name,
+            "useGivenNameAsPrefix": True,
+            "projectId": project_id,
+            "clusterId": cluster_id,
+            "spec": {
+                "command": f"/bin/bash {self.job_dir}/launch_script.sh",
+                "image": self.container_image,
+                "distributedFramework": self.distributed_framework,
+                "minReplicas": self.nodes,
+                "maxReplicas": self.nodes,
+                "numWorkers": self.nodes,
+                "compute": {"gpuDevicesRequest": self.gpus_per_node},
+                "storage": {"pvc": self.pvcs},
+                "environmentVariables": [
+                    {"name": key, "value": value} for key, value in self.env_vars.items()
+                ],
+                **self.custom_spec,
+            },
+        }
+
+        response = requests.post(url, json=payload, headers=headers)
+        logger.debug(
+            "Created distributed job; response code=%s, content=%s",
+            response.status_code,
+            response.text.strip(),
+        )
+        return response
+
+    def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
+        name = name.replace("_", "-").replace(".", "-")  # to meet K8s requirements
+        token = self.get_auth_token()
+        if not token:
+            raise RuntimeError("Failed to get auth token")
+
+        project_id, cluster_id = self.get_project_and_cluster_id(token)
+        if not project_id or not cluster_id:
+            raise RuntimeError("Unable to determine project/cluster IDs for job submission")
+
+        resp = self.create_distributed_job(token, project_id, cluster_id, name, cmd)
+        if resp.status_code not in [200, 202]:
+            raise RuntimeError(f"Failed to create job, status_code={resp.status_code}")
+
+        r_json = resp.json()
+        job_id = r_json["workloadId"]
+        status = r_json["actualPhase"]
+        return job_id, status
+
+    def nnodes(self) -> int:
+        return self.nodes
+
+    def nproc_per_node(self) -> int:
+        return self.gpus_per_node
+
+    def status(self, job_id: str) -> Optional[DGXCloudState]:
+        url = f"{self.base_url}/workloads/distributed/{job_id}"
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for cancellation request.")
+            return None
+
+        headers = self._default_headers(token=token)
+        response = requests.get(url, headers=headers)
+        if response.status_code != 200:
+            return DGXCloudState("Unknown")
+
+        r_json = response.json()
+        return DGXCloudState(r_json["actualPhase"])
+
+    def cancel(self, job_id: str):
+        # Retrieve the authentication token for the REST calls
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for cancellation request.")
+            return
+
+        # Build the DELETE request to cancel the job
+        url = f"{self.base_url}/workloads/distributed/{job_id}/suspend"
+        headers = self._default_headers(token=token)
+
+        response = requests.get(url, headers=headers)
+        if response.status_code >= 200 and response.status_code < 300:
+            logger.info(
+                "Successfully cancelled job %s on DGX with response code %d",
+                job_id,
+                response.status_code,
+            )
+        else:
+            logger.error(
+                "Failed to cancel job %s, response code=%d, reason=%s",
+                job_id,
+                response.status_code,
+                response.text,
+            )
+
+    @classmethod
+    def logs(cls: Type["DGXCloudExecutor"], app_id: str, fallback_path: Optional[str]):
+        logger.warning(
+            "Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs."
+        )
+
+    def cleanup(self, handle: str): ...
+
+    def assign(
+        self,
+        exp_id: str,
+        exp_dir: str,
+        task_id: str,
+        task_dir: str,
+    ):
+        self.job_name = task_id
+        self.experiment_dir = exp_dir
+        self.job_dir = os.path.join(exp_dir, task_dir)
+        self.experiment_id = exp_id
+        os.makedirs(self.job_dir, exist_ok=True)
+        assert any(
+            map(
+                lambda x: os.path.commonpath(
+                    [os.path.abspath(x["path"]), os.path.abspath(self.job_dir)]
+                )
+                == os.path.abspath(x["path"]),
+                self.pvcs,
+            )
+        ), f"Need to specify at least one PVC containing {self.job_dir}.\nTo update job dir to a PVC path, you can set the NEMORUN_HOME env var."
+
+    def package(self, packager: Packager, job_name: str):
+        assert self.experiment_id, "Executor not assigned to an experiment."
+        if isinstance(packager, GitArchivePackager):
+            output = subprocess.run(
+                ["git", "rev-parse", "--show-toplevel"],
+                check=True,
+                stdout=subprocess.PIPE,
+            )
+            path = output.stdout.splitlines()[0].decode()
+            base_path = Path(path).absolute()
+        else:
+            base_path = Path(os.getcwd()).absolute()
+
+        local_pkg = packager.package(base_path, self.job_dir, job_name)
+        local_code_extraction_path = os.path.join(self.job_dir, "code")
+        ctx = Context()
+        ctx.run(f"mkdir -p {local_code_extraction_path}")
+
+        if self.get_launcher().nsys_profile:
+            remote_nsys_extraction_path = os.path.join(
+                self.job_dir, self.get_launcher().nsys_folder
+            )
+            ctx.run(f"mkdir -p {remote_nsys_extraction_path}")
+        if local_pkg:
+            ctx.run(
+                f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
+            )
+
+    def macro_values(self) -> Optional[ExecutorMacros]:
+        return None
+
+    def _default_headers(self, token: Optional[str] = None) -> dict:
+        headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        if token:
+            headers["Authorization"] = f"Bearer {token}"
+        return headers
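To illustrate the PVC check that assign() above enforces, here is a small standalone sketch of the same os.path.commonpath test with hypothetical paths:

import os

# Hypothetical PVC spec and job directory; assign() only inspects the "path" key.
pvcs = [{"path": "/demo-pvc"}]
job_dir = "/demo-pvc/nemo_run/experiments/my-exp/my-task"

# The executor accepts job_dir only if it sits under one of the PVC mount paths.
ok = any(
    os.path.commonpath([os.path.abspath(p["path"]), os.path.abspath(job_dir)])
    == os.path.abspath(p["path"])
    for p in pvcs
)
print(ok)  # True -> assign() would not raise; otherwise point NEMORUN_HOME at a PVC path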

src/nemo_run/run/experiment.py

Lines changed: 9 additions & 2 deletions
@@ -49,6 +49,7 @@
     get_type_namespace,
 )
 from nemo_run.core.execution.base import Executor
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -182,8 +183,14 @@ class Experiment(ConfigurableMixin):
     nemo experiment logs {exp_id} 0
     nemo experiment cancel {exp_id} 0
     """
-    _PARALLEL_SUPPORTED_EXECUTORS = (SlurmExecutor, LocalExecutor, SkypilotExecutor, DockerExecutor)
-    _DETACH_SUPPORTED_EXECUTORS = (SlurmExecutor, SkypilotExecutor)
+    _PARALLEL_SUPPORTED_EXECUTORS = (
+        SlurmExecutor,
+        LocalExecutor,
+        SkypilotExecutor,
+        DockerExecutor,
+        DGXCloudExecutor,
+    )
+    _DETACH_SUPPORTED_EXECUTORS = (SlurmExecutor, SkypilotExecutor, DGXCloudExecutor)
     _DEPENDENCY_SUPPORTED_EXECUTORS = (SlurmExecutor,)
     _RUNNER_DEPENDENT_EXECUTORS = (LocalExecutor,)
     _CONFIG_FILE = "_CONFIG"
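With DGXCloudExecutor added to both tuples, experiments targeting DGX Cloud can run tasks in parallel and be detached. A rough usage sketch follows; the Experiment and Script call signatures are assumptions based on typical NeMo-Run usage, not part of this diff:

import nemo_run as run

# Placeholder credentials/paths; see the configuration sketch near the top of this page.
executor = run.DGXCloudExecutor(
    base_url="https://<runai-control-plane>/api/v1",
    app_id="my-app-id",
    app_secret="my-app-secret",
    project_name="my-project",
    container_image="nvcr.io/nvidia/nemo:24.07",
    pvcs=[{"path": "/demo-pvc"}],
)

with run.Experiment("dgxc-demo") as exp:
    # The add/run keyword arguments below are assumptions about the Experiment API.
    exp.add(run.Script(inline="nvidia-smi"), executor=executor, name="probe")
    exp.run(detach=True)  # detached runs are what _DETACH_SUPPORTED_EXECUTORS gates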

src/nemo_run/run/torchx_backend/components/torchrun.py

Lines changed: 23 additions & 18 deletions
@@ -59,6 +59,7 @@ def torchrun(
     rdzv_backend: str = "c10d",
     mounts: Optional[list[str]] = None,
     debug: bool = False,
+    dgxc: bool = False,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -92,6 +93,7 @@ def torchrun(
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
         debug: whether to run with preset debug flags enabled
+        dgxc: whether to use a subset of settings for DGX Cloud
     """
     if (script is None) == (m is None):
         raise ValueError("exactly one of --script and -m must be specified")
@@ -130,24 +132,27 @@ def torchrun(
     if debug:
         env.update(_TORCH_DEBUG_FLAGS)
 
-    cmd = [
-        "--rdzv-backend",
-        rdzv_backend,
-        "--rdzv-endpoint",
-        rdzv_endpoint,
-        "--rdzv-id",
-        f"{random.randint(1, 10000)}",
-        "--nnodes",
-        num_nodes,
-        "--nproc-per-node",
-        nproc_per_node,
-        "--node-rank",
-        node_rank,
-        "--tee",
-        "3",
-        # "--role",
-        # "",
-    ]
+    if dgxc:
+        cmd = ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
+    else:
+        cmd = [
+            "--rdzv-backend",
+            rdzv_backend,
+            "--rdzv-endpoint",
+            rdzv_endpoint,
+            "--rdzv-id",
+            f"{random.randint(1, 10000)}",
+            "--nnodes",
+            num_nodes,
+            "--nproc-per-node",
+            nproc_per_node,
+            "--node-rank",
+            node_rank,
+            "--tee",
+            "3",
+            # "--role",
+            # "",
+        ]
     if script is not None:
         if no_python:
             cmd += ["--no-python"]
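As a standalone illustration of the branch above: with dgxc=True only the node and per-node process counts are passed to torchrun, since the KubeFlow PyTorch training operator on DGX Cloud injects the rendezvous and rank settings. The values below are hypothetical stand-ins for what torchrun() computes earlier:

import random

def build_torchrun_args(dgxc: bool) -> list[str]:
    # Hypothetical stand-ins for nnodes_rep/num_nodes/nproc_per_node/node_rank, etc.
    nnodes_rep, num_nodes, nproc_per_node, node_rank = "2", "2", "8", "0"
    rdzv_backend, rdzv_endpoint = "c10d", "localhost:29500"
    if dgxc:
        # DGX Cloud: only node and process counts are passed through.
        return ["--nnodes", nnodes_rep, "--nproc-per-node", nproc_per_node]
    return [
        "--rdzv-backend", rdzv_backend,
        "--rdzv-endpoint", rdzv_endpoint,
        "--rdzv-id", f"{random.randint(1, 10000)}",
        "--nnodes", num_nodes,
        "--nproc-per-node", nproc_per_node,
        "--node-rank", node_rank,
        "--tee", "3",
    ]

print(build_torchrun_args(dgxc=True))   # ['--nnodes', '2', '--nproc-per-node', '8']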
