
Commit e4fecfd

Add dgx cloud scheduler
Signed-off-by: Hemil Desai <[email protected]>
1 parent 1ff2194 commit e4fecfd


5 files changed: +335 -85 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ slurm_tunnel = "nemo_run.run.torchx_backend.schedulers.slurm:create_scheduler"
 skypilot = "nemo_run.run.torchx_backend.schedulers.skypilot:create_scheduler"
 local_persistent = "nemo_run.run.torchx_backend.schedulers.local:create_scheduler"
 docker_persistent = "nemo_run.run.torchx_backend.schedulers.docker:create_scheduler"
+dgx_cloud = "nemo_run.run.torchx_backend.schedulers.dgxcloud:create_scheduler"
 
 [project.optional-dependencies]
 skypilot = [
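For context, the sketch below shows one way the new dgx_cloud entry point could be resolved at runtime with importlib.metadata. It is an illustration only; the entry-point group name "torchx.schedulers" is an assumption, since the hunk does not show which table these keys live under.

# Hypothetical sketch: resolving the "dgx_cloud" scheduler factory by entry point.
# The group name below is an assumption, not taken from this diff.
from importlib.metadata import entry_points

def load_scheduler_factory(name: str = "dgx_cloud", group: str = "torchx.schedulers"):
    for ep in entry_points(group=group):
        if ep.name == name:
            # Loads nemo_run.run.torchx_backend.schedulers.dgxcloud:create_scheduler
            return ep.load()
    raise KeyError(f"No scheduler entry point named {name!r} in group {group!r}")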

src/nemo_run/core/execution/dgxcloud.py

Lines changed: 81 additions & 83 deletions
@@ -4,7 +4,7 @@
 import subprocess
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Type
 
 import requests
 from invoke.context import Context
@@ -18,6 +18,24 @@
 
 logger = logging.getLogger(__name__)
 
+from enum import Enum
+
+class DGXCloudState(Enum):
+    CREATING = "Creating"
+    INITIALIZING = "Initializing"
+    RESUMING = "Resuming"
+    PENDING = "Pending"
+    DELETING = "Deleting"
+    RUNNING = "Running"
+    UPDATING = "Updating"
+    STOPPED = "Stopped"
+    STOPPING = "Stopping"
+    DEGRADED = "Degraded"
+    FAILED = "Failed"
+    COMPLETED = "Completed"
+    TERMINATING = "Terminating"
+    UNKNOWN = "Unknown"
+
 
 @dataclass(kw_only=True)
 class DGXCloudExecutor(Executor):
@@ -28,32 +46,20 @@ class DGXCloudExecutor(Executor):
     via a REST API. It acquires an auth token, identifies the project/cluster,
     and launches jobs with a specified command. It can be adapted to meet user
     authentication and job-submission requirements on DGX.
-
-    Example usage might include specifying the environment variables or secrets
-    needed to create new distributed training jobs and storing user-specified
-    configuration (cluster URL, project name, application secrets, etc.).
     """
 
     base_url: str
     app_id: str
     app_secret: str
     project_name: str
-    job_name: str
     container_image: str
     nodes: int = 1
     gpus_per_node: int = 8
     pvcs: list[dict[str, Any]] = field(default_factory=list)
     distributed_framework: str = "PyTorch"
     custom_spec: dict[str, Any] = field(default_factory=dict)
 
-    def __post_init__(self):
-        self.job_name = self.job_name.replace("_", "-")
-
     def get_auth_token(self) -> Optional[str]:
-        """
-        Retrieves the authorization token from the endpoint. Required for subsequent
-        calls to create distributed jobs on the DGX platform.
-        """
         url = f"{self.base_url}/token"
         payload = {
             "grantType": "app_token",
@@ -72,10 +78,6 @@ def get_auth_token(self) -> Optional[str]:
         return auth_token
 
     def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optional[str]]:
-        """
-        Retrieves the project ID and cluster ID by matching the user-provided
-        project_name to the result from the DGX API. Returns (project_id, cluster_id).
-        """
         url = f"{self.base_url}/org-unit/projects"
         headers = self._default_headers(token=token)
         response = requests.get(url, headers=headers)
@@ -90,27 +92,28 @@ def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optional[str]]:
                 break
         return project_id, cluster_id
 
-    def create_distributed_job(self, token: str, project_id: str, cluster_id: str):
+    def create_distributed_job(self, token: str, project_id: str, cluster_id: str, name: str, cmd: list[str]):
         """
         Creates a distributed PyTorch job using the provided project/cluster IDs.
         """
         url = f"{self.base_url}/workloads/distributed"
         headers = self._default_headers(token=token)
+        launch_script = f"""
+ln -s {self.job_dir} /nemo_run
+cd /nemo_run/code
+{" ".join(cmd)}
+"""
+        with open(os.path.join(self.job_dir, "launch_script.sh"), "w+") as f:
+            f.write(launch_script)
+
         payload = {
-            "name": self.job_name,
+            "name": name,
             "useGivenNameAsPrefix": True,
             "projectId": project_id,
             "clusterId": cluster_id,
             "spec": {
-                "command": "echo 'hello' && sleep 60 && echo 'goodbye'",
-                # "args": f"""
-                # # ln -s {self.job_dir} /nemo_run
-                # echo "Hello"
-                # sleep 600
-                # echo "Goodbye"
-                # """,
+                "command": f"/bin/bash {self.job_dir}/launch_script.sh",
                 "image": self.container_image,
-                # "workingDir": "/nemo_run/code",
                 "distributedFramework": self.distributed_framework,
                 "minReplicas": self.nodes,
                 "maxReplicas": self.nodes,
@@ -132,67 +135,69 @@ def create_distributed_job(self, token: str, project_id: str, cluster_id: str):
         )
         return response
 
-    def launch(self, *args, **kwargs) -> tuple[Optional[str], Optional[str]]:
-        """
-        Core entry point to create a token, get the project/cluster, and launch
-        the distributed job on the DGX platform.
-        Returns (job_id, handle) to align with the typical Nemo-Run Executor pattern.
-        """
+    def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
+        name = name.replace("_", "-")  # to meet K8s requirements
         token = self.get_auth_token()
         if not token:
-            logger.error("Cannot proceed without auth token")
-            return None, None
+            raise RuntimeError("Failed to get auth token")
 
         project_id, cluster_id = self.get_project_and_cluster_id(token)
         if not project_id or not cluster_id:
-            logger.error("Unable to determine project/cluster IDs for job submission")
-            return None, None
+            raise RuntimeError("Unable to determine project/cluster IDs for job submission")
 
-        resp = self.create_distributed_job(token, project_id, cluster_id)
+        resp = self.create_distributed_job(token, project_id, cluster_id, name, cmd)
         if resp.status_code not in [200, 202]:
-            logger.error("Failed to create job, status_code=%s", resp.status_code)
-            return None, None
+            raise RuntimeError(f"Failed to create job, status_code={resp.status_code}")
 
-        # For demonstration, parse out some job ID from the response if available
-        try:
-            r_json = resp.json()
-            job_id = r_json.get("id", "dgx_job_id")  # Example ID key
-        except Exception:  # If the response is not valid JSON or no "id"
-            job_id = "dgx_job_id"
+        r_json = resp.json()
+        job_id = r_json["workloadId"]
+        status = r_json["actualPhase"]
+        return job_id, status
 
-        # Typically in Nemo-Run, "handle" can store information for references
-        handle = f"dgx://{job_id}"
-        return job_id, handle
+    def status(self, job_id: str) -> Optional[DGXCloudState]:
+        url = f"{self.base_url}/workloads/distributed/{job_id}"
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for status request.")
+            return None
 
-    def status(self, app_id: str) -> tuple[Optional[str], Optional[dict]]:
-        """
-        Return the job status from the DGX platform. The app_id might be used
-        to query the job ID stored at creation time. For demonstration, this is
-        left abstract, as the API for status queries can be matched to user needs.
-        """
-        logger.debug("Getting status for app_id=%s", app_id)
-        # If a specialized endpoint exists, you would call it here, e.g.:
-        # GET <base_url>/workloads/<job_id>
-        return None, None
+        headers = self._default_headers(token=token)
+        response = requests.get(url, headers=headers)
+        if response.status_code != 200:
+            return DGXCloudState("Unknown")
 
-    def cancel(self, app_id: str):
-        """
-        Cancels the job on the DGX platform. Typically, you'd parse the job_id
-        from app_id and call the relevant REST endpoint to delete/cancel the job.
-        """
-        logger.debug("Attempt to cancel job for app_id=%s", app_id)
+        r_json = response.json()
+        return DGXCloudState(r_json["actualPhase"])
 
-    def logs(self, app_id: str, fallback_path: Optional[str]):
-        """
-        Prints or fetches logs for the job. Typically, you'd parse the job_id
-        from app_id and query a logs endpoint. Fallback logic can be implemented
-        if logs must be fetched from a known file path.
-        """
+    def cancel(self, job_id: str):
+        # Retrieve the authentication token for the REST calls
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for cancellation request.")
+            return
+
+        # Build the request to suspend (cancel) the job
+        url = f"{self.base_url}/workloads/distributed/{job_id}/suspend"
+        headers = self._default_headers(token=token)
+
+        response = requests.get(url, headers=headers)
+        if response.status_code >= 200 and response.status_code < 300:
+            logger.info(
+                "Successfully cancelled job %s on DGX with response code %d",
+                job_id, response.status_code
+            )
+        else:
+            logger.error(
+                "Failed to cancel job %s, response code=%d, reason=%s",
+                job_id, response.status_code, response.text
+            )
+
+    @classmethod
+    def logs(cls: Type["DGXCloudExecutor"], app_id: str, fallback_path: Optional[str]):
+        logger.warning("Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs.")
 
     def cleanup(self, handle: str):
-        """
-        Performs any necessary cleanup after the job has completed.
-        """
+        ...
 
     def assign(
         self,
@@ -201,17 +206,14 @@ def assign(
         task_id: str,
         task_dir: str,
     ):
-        """
-        Assigns the job to a specific experiment run directory in Nemo-Run.
-        """
         self.job_name = task_id
         self.experiment_dir = exp_dir
         self.job_dir = os.path.join(exp_dir, task_dir)
         self.experiment_id = exp_id
         os.makedirs(self.job_dir, exist_ok=True)
         assert any(
-            map(lambda x: Path(self.job_dir).relative_to(Path(x["path"])), self.pvcs)
-        ), f"Need to specify at least one PVC matching {self.job_dir}"
+            map(lambda x: os.path.commonpath([os.path.abspath(x["path"]), os.path.abspath(self.job_dir)]) == os.path.abspath(x["path"]), self.pvcs)
+        ), f"Need to specify at least one PVC containing {self.job_dir}.\nTo update the job dir to a PVC path, you can set the NEMORUN_HOME env var."
 
     def package(self, packager: Packager, job_name: str):
         assert self.experiment_id, "Executor not assigned to an experiment."
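The rewritten PVC assertion above is worth a note: Path.relative_to raises ValueError rather than returning False when job_dir sits outside a PVC path, so the old check could error out instead of failing cleanly; the new version tests containment with os.path.commonpath. A minimal standalone sketch of that check, with made-up paths:

# Minimal sketch of the containment test used by the new assertion; paths are made up.
import os

def pvc_contains(pvc_path: str, job_dir: str) -> bool:
    pvc_path, job_dir = os.path.abspath(pvc_path), os.path.abspath(job_dir)
    # True when job_dir is pvc_path itself or lives underneath it.
    return os.path.commonpath([pvc_path, job_dir]) == pvc_path

print(pvc_contains("/pvc/nemo-runs", "/pvc/nemo-runs/exp_1234/train"))  # True
print(pvc_contains("/pvc/nemo-runs", "/home/user/exp_1234/train"))      # False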
@@ -242,10 +244,6 @@ def package(self, packager: Packager, job_name: str):
         )
 
     def macro_values(self) -> Optional[ExecutorMacros]:
-        """
-        Returns environment macros for distributed training. Not strictly used in this
-        example, but can configure advanced key-value pairs for the job environment.
-        """
         return None
 
     def _default_headers(self, token: Optional[str] = None) -> dict:
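Putting the pieces together, a hedged usage sketch of the reworked executor follows. Every concrete value (endpoint URL, credentials, image, PVC path, command) is a placeholder, and calling launch directly is for illustration only; in practice the Experiment machinery calls assign and package first so that job_dir exists, then drives launch through the dgx_cloud torchx scheduler registered in pyproject.toml.

# Hypothetical usage sketch; all concrete values are placeholders, not from this commit.
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState

executor = DGXCloudExecutor(
    base_url="https://dgx.example.com/api/v1",    # placeholder endpoint
    app_id="my-app-id",                           # placeholder credentials
    app_secret="my-app-secret",
    project_name="my-project",
    container_image="nvcr.io/nvidia/nemo:latest",
    nodes=2,
    gpus_per_node=8,
    pvcs=[{"path": "/pvc/nemo-runs"}],            # simplified; only the "path" key is shown in this diff
)

# launch() now takes an explicit name and command and returns (workload id, phase)
# instead of a dgx:// handle; failures raise RuntimeError instead of returning None.
job_id, phase = executor.launch("pretrain-llm", ["torchrun", "--nnodes=2", "train.py"])

# status() maps the REST response onto DGXCloudState; cancel() suspends the workload.
if executor.status(job_id) not in (DGXCloudState.COMPLETED, DGXCloudState.FAILED):
    executor.cancel(job_id)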

src/nemo_run/run/experiment.py

Lines changed: 3 additions & 2 deletions
@@ -49,6 +49,7 @@
     get_type_namespace,
 )
 from nemo_run.core.execution.base import Executor
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -182,8 +183,8 @@ class Experiment(ConfigurableMixin):
         nemo experiment logs {exp_id} 0
         nemo experiment cancel {exp_id} 0
     """
-    _PARALLEL_SUPPORTED_EXECUTORS = (SlurmExecutor, LocalExecutor, SkypilotExecutor, DockerExecutor)
-    _DETACH_SUPPORTED_EXECUTORS = (SlurmExecutor, SkypilotExecutor)
+    _PARALLEL_SUPPORTED_EXECUTORS = (SlurmExecutor, LocalExecutor, SkypilotExecutor, DockerExecutor, DGXCloudExecutor)
+    _DETACH_SUPPORTED_EXECUTORS = (SlurmExecutor, SkypilotExecutor, DGXCloudExecutor)
     _DEPENDENCY_SUPPORTED_EXECUTORS = (SlurmExecutor,)
     _RUNNER_DEPENDENT_EXECUTORS = (LocalExecutor,)
     _CONFIG_FILE = "_CONFIG"

src/nemo_run/run/torchx_backend/schedulers/api.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torchx.specs import AppDef, AppDryRunInfo
 
 from nemo_run.core.execution.base import Executor
+from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -28,13 +29,15 @@
     SkypilotExecutor: "skypilot",
     LocalExecutor: "local_persistent",
     DockerExecutor: "docker_persistent",
+    DGXCloudExecutor: "dgx_cloud",
 }
 
 REVERSE_EXECUTOR_MAPPING: dict[str, Type[Executor]] = {
     "slurm_tunnel": SlurmExecutor,
     "skypilot": SkypilotExecutor,
     "local_persistent": LocalExecutor,
     "docker_persistent": DockerExecutor,
+    "dgx_cloud": DGXCloudExecutor,
 }
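These two tables let the torchx backend translate between executor classes and scheduler names; for example, the reverse mapping resolves the scheduler key registered in pyproject.toml back to its executor class. A small illustrative check:

# Illustration only: round-trip the new scheduler key through the reverse mapping.
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.run.torchx_backend.schedulers.api import REVERSE_EXECUTOR_MAPPING

assert REVERSE_EXECUTOR_MAPPING["dgx_cloud"] is DGXCloudExecutor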
