 import subprocess
 from dataclasses import dataclass, field
+from enum import Enum
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Type
 
 import requests
 from invoke.context import Context
 
 logger = logging.getLogger(__name__)
 
+
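+# Workload states reported in the "actualPhase" field of DGX Cloud workload responses.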
+class DGXCloudState(Enum):
+    CREATING = "Creating"
+    INITIALIZING = "Initializing"
+    RESUMING = "Resuming"
+    PENDING = "Pending"
+    DELETING = "Deleting"
+    RUNNING = "Running"
+    UPDATING = "Updating"
+    STOPPED = "Stopped"
+    STOPPING = "Stopping"
+    DEGRADED = "Degraded"
+    FAILED = "Failed"
+    COMPLETED = "Completed"
+    TERMINATING = "Terminating"
+    UNKNOWN = "Unknown"
+
 
 @dataclass(kw_only=True)
 class DGXCloudExecutor(Executor):
@@ -28,32 +46,20 @@ class DGXCloudExecutor(Executor):
     via a REST API. It acquires an auth token, identifies the project/cluster,
     and launches jobs with a specified command. It can be adapted to meet user
     authentication and job-submission requirements on DGX.
-
-    Example usage might include specifying the environment variables or secrets
-    needed to create new distributed training jobs and storing user-specified
-    configuration (cluster URL, project name, application secrets, etc.).
     """
 
     base_url: str
     app_id: str
     app_secret: str
     project_name: str
-    job_name: str
     container_image: str
     nodes: int = 1
     gpus_per_node: int = 8
     pvcs: list[dict[str, Any]] = field(default_factory=list)
     distributed_framework: str = "PyTorch"
     custom_spec: dict[str, Any] = field(default_factory=dict)
 
-    def __post_init__(self):
-        self.job_name = self.job_name.replace("_", "-")
-
     def get_auth_token(self) -> Optional[str]:
-        """
-        Retrieves the authorization token from the endpoint. Required for subsequent
-        calls to create distributed jobs on the DGX platform.
-        """
         url = f"{self.base_url}/token"
         payload = {
             "grantType": "app_token",
@@ -72,10 +78,6 @@ def get_auth_token(self) -> Optional[str]:
         return auth_token
 
     def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optional[str]]:
-        """
-        Retrieves the project ID and cluster ID by matching the user-provided
-        project_name to the result from the DGX API. Returns (project_id, cluster_id).
-        """
         url = f"{self.base_url}/org-unit/projects"
         headers = self._default_headers(token=token)
         response = requests.get(url, headers=headers)
@@ -90,27 +92,28 @@ def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optiona
                 break
         return project_id, cluster_id
 
-    def create_distributed_job(self, token: str, project_id: str, cluster_id: str):
+    def create_distributed_job(self, token: str, project_id: str, cluster_id: str, name: str, cmd: list[str]):
         """
         Creates a distributed PyTorch job using the provided project/cluster IDs.
         """
         url = f"{self.base_url}/workloads/distributed"
         headers = self._default_headers(token=token)
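+        # Stage a launch script in the shared job dir; the workload executes it via /bin/bash (see "command" below).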
+        launch_script = f"""
+ln -s {self.job_dir} /nemo_run
+cd /nemo_run/code
+{" ".join(cmd)}
+"""
+        with open(os.path.join(self.job_dir, "launch_script.sh"), "w+") as f:
+            f.write(launch_script)
+
         payload = {
-            "name": self.job_name,
+            "name": name,
             "useGivenNameAsPrefix": True,
             "projectId": project_id,
             "clusterId": cluster_id,
             "spec": {
-                "command": "echo 'hello' && sleep 60 && echo 'goodbye'",
-                # "args": f"""
-                # # ln -s {self.job_dir} /nemo_run
-                # echo "Hello"
-                # sleep 600
-                # echo "Goodbye"
-                # """,
+                "command": f"/bin/bash {self.job_dir}/launch_script.sh",
                 "image": self.container_image,
-                # "workingDir": "/nemo_run/code",
                 "distributedFramework": self.distributed_framework,
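+                # Pin min and max replicas to the node count so the job runs at a fixed size.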
                 "minReplicas": self.nodes,
                 "maxReplicas": self.nodes,
@@ -132,67 +135,69 @@ def create_distributed_job(self, token: str, project_id: str, cluster_id: str):
         )
         return response
 
-    def launch(self, *args, **kwargs) -> tuple[Optional[str], Optional[str]]:
-        """
-        Core entry point to create a token, get the project/cluster, and launch
-        the distributed job on the DGX platform.
-        Returns (job_id, handle) to align with the typical Nemo-Run Executor pattern.
-        """
+    def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
+        name = name.replace("_", "-")  # to meet K8s requirements
         token = self.get_auth_token()
         if not token:
-            logger.error("Cannot proceed without auth token")
-            return None, None
+            raise RuntimeError("Failed to get auth token")
 
         project_id, cluster_id = self.get_project_and_cluster_id(token)
         if not project_id or not cluster_id:
-            logger.error("Unable to determine project/cluster IDs for job submission")
-            return None, None
+            raise RuntimeError("Unable to determine project/cluster IDs for job submission")
 
-        resp = self.create_distributed_job(token, project_id, cluster_id)
+        resp = self.create_distributed_job(token, project_id, cluster_id, name, cmd)
         if resp.status_code not in [200, 202]:
-            logger.error("Failed to create job, status_code=%s", resp.status_code)
-            return None, None
+            raise RuntimeError(f"Failed to create job, status_code={resp.status_code}")
 
-        # For demonstration, parse out some job ID from the response if available
-        try:
-            r_json = resp.json()
-            job_id = r_json.get("id", "dgx_job_id")  # Example ID key
-        except Exception:  # If the response is not valid JSON or no "id"
-            job_id = "dgx_job_id"
+        r_json = resp.json()
+        job_id = r_json["workloadId"]
+        status = r_json["actualPhase"]
+        return job_id, status
 
-        # Typically in Nemo-Run, "handle" can store information for references
-        handle = f"dgx://{job_id}"
-        return job_id, handle
+    def status(self, job_id: str) -> Optional[DGXCloudState]:
+        url = f"{self.base_url}/workloads/distributed/{job_id}"
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for status request.")
+            return None
 
-    def status(self, app_id: str) -> tuple[Optional[str], Optional[dict]]:
-        """
-        Return the job status from the DGX platform. The app_id might be used
-        to query the job ID stored at creation time. For demonstration, this is
-        left abstract, as the API for status queries can be matched to user needs.
-        """
-        logger.debug("Getting status for app_id=%s", app_id)
-        # If a specialized endpoint exists, you would call it here, e.g.:
-        # GET <base_url>/workloads/<job_id>
-        return None, None
+        headers = self._default_headers(token=token)
+        response = requests.get(url, headers=headers)
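+        # Treat a non-200 response as an unknown state rather than raising.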
+        if response.status_code != 200:
+            return DGXCloudState.UNKNOWN
 
-    def cancel(self, app_id: str):
-        """
-        Cancels the job on the DGX platform. Typically, you'd parse the job_id
-        from app_id and call the relevant REST endpoint to delete/cancel the job.
-        """
-        logger.debug("Attempt to cancel job for app_id=%s", app_id)
+        r_json = response.json()
+        return DGXCloudState(r_json["actualPhase"])
 
-    def logs(self, app_id: str, fallback_path: Optional[str]):
-        """
-        Prints or fetches logs for the job. Typically, you'd parse the job_id
-        from app_id and query a logs endpoint. Fallback logic can be implemented
-        if logs must be fetched from a known file path.
-        """
+    def cancel(self, job_id: str):
+        # Retrieve the authentication token for the REST calls
+        token = self.get_auth_token()
+        if not token:
+            logger.error("Failed to retrieve auth token for cancellation request.")
+            return
+
+        # Hit the suspend endpoint to cancel the running workload
+        url = f"{self.base_url}/workloads/distributed/{job_id}/suspend"
+        headers = self._default_headers(token=token)
+
+        response = requests.get(url, headers=headers)
+        if 200 <= response.status_code < 300:
+            logger.info(
+                "Successfully cancelled job %s on DGX with response code %d",
+                job_id, response.status_code
+            )
+        else:
+            logger.error(
+                "Failed to cancel job %s, response code=%d, reason=%s",
+                job_id, response.status_code, response.text
+            )
+
+    @classmethod
+    def logs(cls: Type["DGXCloudExecutor"], app_id: str, fallback_path: Optional[str]):
+        logger.warning("Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs.")
 
     def cleanup(self, handle: str):
-        """
-        Performs any necessary cleanup after the job has completed.
-        """
+        ...
 
     def assign(
         self,
@@ -201,17 +206,14 @@ def assign(
         task_id: str,
         task_dir: str,
     ):
-        """
-        Assigns the job to a specific experiment run directory in Nemo-Run.
-        """
         self.job_name = task_id
         self.experiment_dir = exp_dir
         self.job_dir = os.path.join(exp_dir, task_dir)
         self.experiment_id = exp_id
         os.makedirs(self.job_dir, exist_ok=True)
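+        # The job dir must live under one of the configured PVC paths so it is visible to the cluster.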
         assert any(
-            map(lambda x: Path(self.job_dir).relative_to(Path(x["path"])), self.pvcs)
-        ), f"Need to specify atleast one PVC matching {self.job_dir}"
+            map(
+                lambda x: os.path.commonpath([os.path.abspath(x["path"]), os.path.abspath(self.job_dir)])
+                == os.path.abspath(x["path"]),
+                self.pvcs,
+            )
+        ), f"Need to specify at least one PVC containing {self.job_dir}.\nTo update the job dir to a PVC path, you can set the NEMORUN_HOME env var."
 
     def package(self, packager: Packager, job_name: str):
         assert self.experiment_id, "Executor not assigned to an experiment."
@@ -242,10 +244,6 @@ def package(self, packager: Packager, job_name: str):
         )
 
     def macro_values(self) -> Optional[ExecutorMacros]:
-        """
-        Returns environment macros for distributed training. Not strictly used in this
-        example, but can configure advanced key-value pairs for the job environment.
-        """
         return None
 
     def _default_headers(self, token: Optional[str] = None) -> dict: