Skip to content

Commit 61bb965

Browse files
committed
Fix formatting
Signed-off-by: Hemil Desai <[email protected]>
1 parent e4fecfd commit 61bb965

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

src/nemo_run/core/execution/dgxcloud.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import subprocess
55
from dataclasses import dataclass, field
6+
from enum import Enum
67
from pathlib import Path
78
from typing import Any, Optional, Type
89

@@ -18,7 +19,6 @@
1819

1920
logger = logging.getLogger(__name__)
2021

21-
from enum import Enum
2222

2323
class DGXCloudState(Enum):
2424
CREATING = "Creating"
@@ -92,7 +92,9 @@ def get_project_and_cluster_id(self, token: str) -> tuple[Optional[str], Optiona
9292
break
9393
return project_id, cluster_id
9494

95-
def create_distributed_job(self, token: str, project_id: str, cluster_id: str, name:str, cmd: list[str]):
95+
def create_distributed_job(
96+
self, token: str, project_id: str, cluster_id: str, name: str, cmd: list[str]
97+
):
9698
"""
9799
Creates a distributed PyTorch job using the provided project/cluster IDs.
98100
"""
@@ -136,7 +138,7 @@ def create_distributed_job(self, token: str, project_id: str, cluster_id: str, n
136138
return response
137139

138140
def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
139-
name = name.replace("_", "-") # to meet K8s requirements
141+
name = name.replace("_", "-") # to meet K8s requirements
140142
token = self.get_auth_token()
141143
if not token:
142144
raise RuntimeError("Failed to get auth token")
@@ -184,20 +186,24 @@ def cancel(self, job_id: str):
184186
if response.status_code >= 200 and response.status_code < 300:
185187
logger.info(
186188
"Successfully cancelled job %s on DGX with response code %d",
187-
job_id, response.status_code
189+
job_id,
190+
response.status_code,
188191
)
189192
else:
190193
logger.error(
191194
"Failed to cancel job %s, response code=%d, reason=%s",
192-
job_id, response.status_code, response.text
195+
job_id,
196+
response.status_code,
197+
response.text,
193198
)
194199

195200
@classmethod
196201
def logs(cls: Type["DGXCloudExecutor"], app_id: str, fallback_path: Optional[str]):
197-
logger.warning("Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs.")
202+
logger.warning(
203+
"Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs."
204+
)
198205

199-
def cleanup(self, handle: str):
200-
...
206+
def cleanup(self, handle: str): ...
201207

202208
def assign(
203209
self,
@@ -212,7 +218,13 @@ def assign(
212218
self.experiment_id = exp_id
213219
os.makedirs(self.job_dir, exist_ok=True)
214220
assert any(
215-
map(lambda x: os.path.commonpath([os.path.abspath(x["path"]), os.path.abspath(self.job_dir)]) == os.path.abspath(x["path"]), self.pvcs)
221+
map(
222+
lambda x: os.path.commonpath(
223+
[os.path.abspath(x["path"]), os.path.abspath(self.job_dir)]
224+
)
225+
== os.path.abspath(x["path"]),
226+
self.pvcs,
227+
)
216228
), f"Need to specify atleast one PVC containing {self.job_dir}.\nTo update job dir to a PVC path, you can set the NEMORUN_HOME env var."
217229

218230
def package(self, packager: Packager, job_name: str):

src/nemo_run/run/experiment.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,13 @@ class Experiment(ConfigurableMixin):
183183
nemo experiment logs {exp_id} 0
184184
nemo experiment cancel {exp_id} 0
185185
"""
186-
_PARALLEL_SUPPORTED_EXECUTORS = (SlurmExecutor, LocalExecutor, SkypilotExecutor, DockerExecutor, DGXCloudExecutor)
186+
_PARALLEL_SUPPORTED_EXECUTORS = (
187+
SlurmExecutor,
188+
LocalExecutor,
189+
SkypilotExecutor,
190+
DockerExecutor,
191+
DGXCloudExecutor,
192+
)
187193
_DETACH_SUPPORTED_EXECUTORS = (SlurmExecutor, SkypilotExecutor, DGXCloudExecutor)
188194
_DEPENDENCY_SUPPORTED_EXECUTORS = (SlurmExecutor,)
189195
_RUNNER_DEPENDENT_EXECUTORS = (LocalExecutor,)

src/nemo_run/run/torchx_backend/schedulers/dgxcloud.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,14 @@ class DGXRequest:
6060
Wrapper around the torchx AppDef and the DGX executor.
6161
This object is used to store job submission info for the scheduler.
6262
"""
63+
6364
app: AppDef
6465
executor: DGXCloudExecutor
6566
cmd: list[str]
6667
name: str
6768

6869

69-
class DGXCloudScheduler(SchedulerMixin, Scheduler[dict[str, str]]): # type: ignore
70+
class DGXCloudScheduler(SchedulerMixin, Scheduler[dict[str, str]]): # type: ignore
7071
def __init__(self, session_name: str) -> None:
7172
super().__init__("dgx", session_name)
7273

@@ -76,11 +77,11 @@ def _run_opts(self) -> runopts:
7677
"job_dir",
7778
type_=str,
7879
help="The directory to place the job code and outputs."
79-
" The directory must not exist and will be created.",
80+
" The directory must not exist and will be created.",
8081
)
8182
return opts
8283

83-
def _submit_dryrun( # type: ignore
84+
def _submit_dryrun( # type: ignore
8485
self,
8586
app: AppDef,
8687
cfg: Executor,
@@ -100,7 +101,7 @@ def _submit_dryrun( # type: ignore
100101
return AppDryRunInfo(
101102
DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
102103
# Minimal function to show the config, if any
103-
lambda req: f"DGX job for app: {req.app.name}, cmd: {' '.join(cmd)}, executor: {executor}"
104+
lambda req: f"DGX job for app: {req.app.name}, cmd: {' '.join(cmd)}, executor: {executor}",
104105
)
105106

106107
def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:
@@ -148,20 +149,15 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
148149
RoleStatus(
149150
role_name,
150151
replicas=[
151-
ReplicaStatus(
152-
id=0,
153-
role=role_name,
154-
state=AppState.SUBMITTED,
155-
hostname=""
156-
)
152+
ReplicaStatus(id=0, role=role_name, state=AppState.SUBMITTED, hostname="")
157153
],
158154
)
159155
]
160156

161157
if not job_info:
162158
return None
163159

164-
executor: DGXCloudExecutor = job_info.get("executor", None) # type: ignore
160+
executor: DGXCloudExecutor = job_info.get("executor", None) # type: ignore
165161
if not executor:
166162
return None
167163

@@ -175,7 +171,7 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
175171
roles_statuses=roles_statuses,
176172
state=app_state,
177173
msg="",
178-
ui_url=f"{executor.base_url}/workloads/distributed/{job_id}"
174+
ui_url=f"{executor.base_url}/workloads/distributed/{job_id}",
179175
)
180176

181177
def _cancel_existing(self, app_id: str) -> None:
@@ -185,7 +181,7 @@ def _cancel_existing(self, app_id: str) -> None:
185181
stored_data = _get_job_dirs()
186182
job_info = stored_data.get(app_id)
187183
_, _, job_id = app_id.split("___")
188-
executor: DGXCloudExecutor = job_info.get("executor", None) # type: ignore
184+
executor: DGXCloudExecutor = job_info.get("executor", None) # type: ignore
189185
if not executor:
190186
return None
191187
executor.delete(job_id)
@@ -219,7 +215,9 @@ def _save_job_dir(app_id: str, job_status: str, executor: DGXCloudExecutor) -> N
219215

220216
app = {
221217
"job_status": job_status,
222-
"executor": serializer.serialize(fdl_dc.convert_dataclasses_to_configs(executor, allow_post_init=True)),
218+
"executor": serializer.serialize(
219+
fdl_dc.convert_dataclasses_to_configs(executor, allow_post_init=True)
220+
),
223221
}
224222
original_apps[app_id] = app
225223

0 commit comments

Comments (0)