
Commit 1ff2194

Add DGXCloudExecutor
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 11e0d2f

2 files changed: +260 -0 lines changed

src/nemo_run/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     Torchrun,
     import_executor,
 )
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -46,6 +47,7 @@
     "ConfigurableMixin",
     "DevSpace",
     "DockerExecutor",
+    "DGXCloudExecutor",
     "dryrun_fn",
     "Executor",
     "import_executor",
src/nemo_run/core/execution/dgxcloud.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
import json
import logging
import os
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

import requests
from invoke.context import Context

from nemo_run.core.execution.base import (
    Executor,
    ExecutorMacros,
)
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

logger = logging.getLogger(__name__)


@dataclass(kw_only=True)
class DGXCloudExecutor(Executor):
    """
    Dataclass to configure a DGX Executor.

    This executor integrates with a DGX cloud endpoint for launching jobs
    via a REST API. It acquires an auth token, identifies the project/cluster,
    and launches jobs with a specified command. It can be adapted to meet user
    authentication and job-submission requirements on DGX.

    Example usage might include specifying the environment variables or secrets
    needed to create new distributed training jobs and storing user-specified
    configuration (cluster URL, project name, application secrets, etc.).
    """

    base_url: str
    app_id: str
    app_secret: str
    project_name: str
    job_name: str
    container_image: str
    nodes: int = 1
    gpus_per_node: int = 8
    pvcs: list[dict[str, Any]] = field(default_factory=list)
    distributed_framework: str = "PyTorch"
    custom_spec: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        self.job_name = self.job_name.replace("_", "-")

    def get_auth_token(self) -> Optional[str]:
        """
        Retrieves the authorization token from the endpoint. Required for subsequent
        calls to create distributed jobs on the DGX platform.
        """
        url = f"{self.base_url}/token"
        payload = {
            "grantType": "app_token",
            "appId": self.app_id,
            "appSecret": self.app_secret,
        }

        response = requests.post(url, json=payload, headers=self._default_headers())
        response_text = response.text.strip()
        auth_token = json.loads(response_text).get("accessToken", None)
        if not auth_token:
            logger.error("Failed to retrieve auth token; response was: %s", response_text)
            return None

        logger.debug("Retrieved auth token from %s", url)
        return auth_token

    def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optional[str]]:
        """
        Retrieves the project ID and cluster ID by matching the user-provided
        project_name to the result from the DGX API. Returns (project_id, cluster_id).
        """
        url = f"{self.base_url}/org-unit/projects"
        headers = self._default_headers(token=token)
        response = requests.get(url, headers=headers)
        projects = json.loads(response.text.strip()).get("projects", [])
        project_id, cluster_id = None, None
        for prj in projects:
            if not self.project_name or prj["name"] == self.project_name:
                project_id, cluster_id = prj["id"], prj["clusterId"]
                logger.debug(
                    "Found project '%s' (%s) on cluster '%s'", prj["name"], project_id, cluster_id
                )
                break
        return project_id, cluster_id

    def create_distributed_job(self, token: str, project_id: str, cluster_id: str):
        """
        Creates a distributed PyTorch job using the provided project/cluster IDs.
        """
        url = f"{self.base_url}/workloads/distributed"
        headers = self._default_headers(token=token)
        payload = {
            "name": self.job_name,
            "useGivenNameAsPrefix": True,
            "projectId": project_id,
            "clusterId": cluster_id,
            "spec": {
                "command": "echo 'hello' && sleep 60 && echo 'goodbye'",
                # "args": f"""
                # # ln -s {self.job_dir} /nemo_run
                # echo "Hello"
                # sleep 600
                # echo "Goodbye"
                # """,
                "image": self.container_image,
                # "workingDir": "/nemo_run/code",
                "distributedFramework": self.distributed_framework,
                "minReplicas": self.nodes,
                "maxReplicas": self.nodes,
                "numWorkers": self.nodes,
                "compute": {"gpuDevicesRequest": self.gpus_per_node},
                "storage": {"pvc": self.pvcs},
                "environmentVariables": [
                    {"name": key, "value": value} for key, value in self.env_vars.items()
                ],
                **self.custom_spec,
            },
        }

        response = requests.post(url, json=payload, headers=headers)
        logger.debug(
            "Created distributed job; response code=%s, content=%s",
            response.status_code,
            response.text.strip(),
        )
        return response

    def launch(self, *args, **kwargs) -> tuple[Optional[str], Optional[str]]:
        """
        Core entry point to create a token, get the project/cluster, and launch
        the distributed job on the DGX platform.
        Returns (job_id, handle) to align with the typical Nemo-Run Executor pattern.
        """
        token = self.get_auth_token()
        if not token:
            logger.error("Cannot proceed without auth token")
            return None, None

        project_id, cluster_id = self.get_project_and_cluster_id(token)
        if not project_id or not cluster_id:
            logger.error("Unable to determine project/cluster IDs for job submission")
            return None, None

        resp = self.create_distributed_job(token, project_id, cluster_id)
        if resp.status_code not in [200, 202]:
            logger.error("Failed to create job, status_code=%s", resp.status_code)
            return None, None

        # For demonstration, parse out some job ID from the response if available
        try:
            r_json = resp.json()
            job_id = r_json.get("id", "dgx_job_id")  # Example ID key
        except Exception:  # If the response is not valid JSON
            job_id = "dgx_job_id"

        # Typically in Nemo-Run, "handle" can store information for references
        handle = f"dgx://{job_id}"
        return job_id, handle

    def status(self, app_id: str) -> tuple[Optional[str], Optional[dict]]:
        """
        Return the job status from the DGX platform. The app_id might be used
        to query the job ID stored at creation time. For demonstration, this is
        left abstract, as the API for status queries can be matched to user needs.
        """
        logger.debug("Getting status for app_id=%s", app_id)
        # If a specialized endpoint exists, you would call it here, e.g.:
        # GET <base_url>/workloads/<job_id>
        return None, None

    def cancel(self, app_id: str):
        """
        Cancels the job on the DGX platform. Typically, you'd parse the job_id
        from app_id and call the relevant REST endpoint to delete/cancel the job.
        """
        logger.debug("Attempt to cancel job for app_id=%s", app_id)

    def logs(self, app_id: str, fallback_path: Optional[str]):
        """
        Prints or fetches logs for the job. Typically, you'd parse the job_id
        from app_id and query a logs endpoint. Fallback logic can be implemented
        if logs must be fetched from a known file path.
        """

    def cleanup(self, handle: str):
        """
        Performs any necessary cleanup after the job has completed.
        """

    def assign(
        self,
        exp_id: str,
        exp_dir: str,
        task_id: str,
        task_dir: str,
    ):
        """
        Assigns the job to a specific experiment run directory in Nemo-Run.
        """
        self.job_name = task_id
        self.experiment_dir = exp_dir
        self.job_dir = os.path.join(exp_dir, task_dir)
        self.experiment_id = exp_id
        os.makedirs(self.job_dir, exist_ok=True)
        # At least one configured PVC must mount a path that contains the job directory.
        assert any(
            Path(self.job_dir).is_relative_to(Path(x["path"])) for x in self.pvcs
        ), f"Need to specify at least one PVC matching {self.job_dir}"

    def package(self, packager: Packager, job_name: str):
        assert self.experiment_id, "Executor not assigned to an experiment."
        if isinstance(packager, GitArchivePackager):
            output = subprocess.run(
                ["git", "rev-parse", "--show-toplevel"],
                check=True,
                stdout=subprocess.PIPE,
            )
            path = output.stdout.splitlines()[0].decode()
            base_path = Path(path).absolute()
        else:
            base_path = Path(os.getcwd()).absolute()

        local_pkg = packager.package(base_path, self.job_dir, job_name)
        local_code_extraction_path = os.path.join(self.job_dir, "code")
        ctx = Context()
        ctx.run(f"mkdir -p {local_code_extraction_path}")

        if self.get_launcher().nsys_profile:
            remote_nsys_extraction_path = os.path.join(
                self.job_dir, self.get_launcher().nsys_folder
            )
            ctx.run(f"mkdir -p {remote_nsys_extraction_path}")
        if local_pkg:
            ctx.run(
                f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
            )

    def macro_values(self) -> Optional[ExecutorMacros]:
        """
        Returns environment macros for distributed training. Not strictly used in this
        example, but can configure advanced key-value pairs for the job environment.
        """
        return None

    def _default_headers(self, token: Optional[str] = None) -> dict:
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers
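For reference, here is a minimal, illustrative sketch of how the new executor could be driven directly through its launch() entry point. It is based only on the fields and methods defined above; the endpoint URL, credentials, project, container image, and PVC entry are hypothetical placeholders, and the exact PVC dictionary shape depends on the DGX Cloud API rather than on this commit.

from nemo_run.core.execution.dgxcloud import DGXCloudExecutor

# All values below are placeholders, not real endpoints or credentials.
executor = DGXCloudExecutor(
    base_url="https://dgx-cloud.example.com/api/v1",  # hypothetical REST endpoint
    app_id="my-app-id",                               # hypothetical application ID
    app_secret="my-app-secret",                       # hypothetical application secret
    project_name="my-project",
    job_name="hello_world",       # underscores are replaced with dashes in __post_init__
    container_image="nvcr.io/nvidia/nemo:24.07",      # example image tag
    nodes=2,
    gpus_per_node=8,
    # Each entry needs at least a "path"; any other keys follow the platform's PVC spec.
    pvcs=[{"path": "/nemo_run"}],
)

# launch() acquires an auth token, resolves the project/cluster IDs,
# and submits the distributed job, returning (job_id, handle) or (None, None).
job_id, handle = executor.launch()

In an experiment context, assign() and package() (defined above) would additionally stage the job directory and packaged code onto a PVC whose path contains the job directory before launching.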
