Skip to content

Commit a61734b

Browse files
authored
Use thread pool for status, run methods inside experiment + other fixes (#295)
* Add thread pool to get status of jobs inside experiment — Signed-off-by: Hemil Desai <[email protected]>
* Add thread pools to experiment run — Signed-off-by: Hemil Desai <[email protected]>
* Fixes — Signed-off-by: Hemil Desai <[email protected]>
* Fix — Signed-off-by: Hemil Desai <[email protected]>
* Fix — Signed-off-by: Hemil Desai <[email protected]>
* Fix — Signed-off-by: Hemil Desai <[email protected]>
* Fix — Signed-off-by: Hemil Desai <[email protected]>
---------
Signed-off-by: Hemil Desai <[email protected]>
1 parent 6b01546 commit a61734b

File tree

9 files changed

+316
-34
lines changed

9 files changed

+316
-34
lines changed

nemo_run/core/execution/slurm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,10 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
590590
def package(self, packager: Packager, job_name: str):
591591
assert self.experiment_id, "Executor not assigned to an experiment."
592592

593-
if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir:
593+
if (
594+
get_packaging_job_key(self.experiment_id, job_name) in self.tunnel.packaging_jobs
595+
and not packager.symlink_from_remote_dir
596+
):
594597
logger.info(
595598
f"Packaging for job {job_name} in tunnel {self.tunnel.key} already done. Skipping subsequent packagings.\n"
596599
"This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used."

nemo_run/run/experiment.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535
from rich.console import Group
3636
from rich.live import Live
3737
from rich.panel import Panel
38-
from rich.progress import BarColumn, Progress, SpinnerColumn
38+
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskID, TimeElapsedColumn
3939
from rich.progress import Task as RichTask
40-
from rich.progress import TaskID, TimeElapsedColumn
4140
from rich.syntax import Syntax
4241
from torchx.specs.api import AppState
4342

@@ -303,6 +302,9 @@ def __init__(
303302
base_dir: str | None = None,
304303
clean_mode: bool = False,
305304
enable_goodbye_message: bool = True,
305+
threadpool_workers: int = 16,
306+
skip_status_at_exit: bool = False,
307+
serialize_metadata_for_scripts: bool = True,
306308
) -> None:
307309
"""
308310
Initializes an experiment run by creating its metadata directory and saving the experiment config.
@@ -328,6 +330,9 @@ def __init__(
328330
self._title = title
329331
self._id = id or f"{title}_{int(time.time())}"
330332
self._enable_goodbye_message = enable_goodbye_message
333+
self._threadpool_workers = threadpool_workers
334+
self._skip_status_at_exit = skip_status_at_exit
335+
self._serialize_metadata_for_scripts = serialize_metadata_for_scripts
331336

332337
base_dir = str(base_dir or get_nemorun_home())
333338
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)
@@ -359,6 +364,10 @@ def to_config(self) -> Config:
359364
executor=self.executor.to_config(),
360365
log_level=self.log_level,
361366
clean_mode=self.clean_mode,
367+
threadpool_workers=self._threadpool_workers,
368+
enable_goodbye_message=self._enable_goodbye_message,
369+
skip_status_at_exit=self._skip_status_at_exit,
370+
serialize_metadata_for_scripts=self._serialize_metadata_for_scripts,
362371
)
363372

364373
def _save_experiment(self, exist_ok: bool = False):
@@ -422,8 +431,9 @@ def _load_jobs(self) -> list[Job | JobGroup]:
422431

423432
def _prepare(self, exist_ok: bool = False):
424433
self._save_experiment(exist_ok=exist_ok)
434+
425435
for job in self.jobs:
426-
job.prepare()
436+
job.prepare(serialize_metadata_for_scripts=self._serialize_metadata_for_scripts)
427437

428438
self._save_jobs()
429439

@@ -769,7 +779,15 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
769779
self.detach = detach
770780

771781
for level in order:
772-
for _, node in enumerate(level):
782+
# Launch jobs in this level concurrently since they are independent
783+
784+
def _set_context(ctx: contextvars.Context):
785+
for var, value in ctx.items():
786+
var.set(value)
787+
788+
ctx = contextvars.copy_context()
789+
790+
def _launch(node: str):
773791
job: Job | JobGroup = job_map[node]
774792
self.console.log(f"[bold cyan]Launching job {job.id} for experiment {self._title}")
775793
if tail_logs:
@@ -787,14 +805,24 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
787805
deps.append(handle)
788806

789807
job.executor.dependencies = deps # type: ignore
808+
790809
job.launch(wait=False, runner=self._runner)
810+
return job
791811

792812
except Exception as e:
793813
self.console.log(f"Error running job {job.id}: {e}")
794814
raise e
795815

816+
launched_jobs: list[Job | JobGroup] = []
817+
with ThreadPoolExecutor(
818+
initializer=_set_context, initargs=(ctx,), max_workers=self._threadpool_workers
819+
) as pool:
820+
futures = [pool.submit(_launch, node) for node in level]
821+
for future in as_completed(futures):
822+
launched_jobs.append(future.result())
823+
796824
if wait:
797-
self._wait_for_jobs(jobs=[job_map[node] for node in level])
825+
self._wait_for_jobs(jobs=launched_jobs)
798826

799827
self._save_jobs()
800828
self._launched = any(map(lambda job: job.launched, self.jobs))
@@ -840,7 +868,21 @@ def set_context(context: contextvars.Context):
840868
finally:
841869
job.cleanup()
842870

843-
def status(self, return_dict: bool = False) -> Optional[dict[str, str]]:
871+
def _initialize_tunnels(self, extract_from_executors: bool = False):
872+
if extract_from_executors:
873+
for job in self.jobs:
874+
if (
875+
isinstance(job.executor, SlurmExecutor)
876+
and job.executor.tunnel.key not in self.tunnels
877+
):
878+
self.tunnels[job.executor.tunnel.key] = job.executor.tunnel
879+
880+
for tunnel in self.tunnels.values():
881+
if isinstance(tunnel, SSHTunnel):
882+
tunnel.connect()
883+
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
884+
885+
def status(self, return_dict: bool = False) -> Optional[dict[str, dict[str, str]]]:
844886
"""
845887
Prints a table specifying the status of all tasks.
846888
@@ -880,6 +922,7 @@ def _get_job_info_and_dict(
880922
"status": job.status(runner=self._runner),
881923
"executor": job.executor.info(),
882924
"job_id": app_id,
925+
"handle": job.handle,
883926
"local_dir": job.executor.job_dir,
884927
}
885928

@@ -902,13 +945,34 @@ def _get_job_info_and_dict(
902945
job_info.extend(directory_info)
903946
return job_info, job_dict
904947

948+
self._initialize_tunnels(extract_from_executors=True)
905949
try:
906950
result_dict = {}
907-
job_infos = []
908-
for i, job in enumerate(self.jobs):
909-
job_info, job_dict = _get_job_info_and_dict(i, job)
910-
job_infos.append(Group(*job_info))
911-
result_dict[job.id] = job_dict
951+
job_infos: list[Group | None] = [None] * len(self.jobs)
952+
953+
# Parallelize IO-bound status retrieval across jobs
954+
def _collect(arg):
955+
idx, job = arg
956+
job_info, job_dict = _get_job_info_and_dict(idx, job)
957+
return idx, job.id, job_info, job_dict
958+
959+
# Propagate context variables to worker threads so helpers that rely on them keep working
960+
def _set_context(ctx: contextvars.Context):
961+
for var, value in ctx.items():
962+
var.set(value)
963+
964+
ctx = contextvars.copy_context()
965+
with ThreadPoolExecutor(
966+
initializer=_set_context, initargs=(ctx,), max_workers=self._threadpool_workers
967+
) as pool:
968+
futures = [pool.submit(_collect, (idx, job)) for idx, job in enumerate(self.jobs)]
969+
for future in as_completed(futures):
970+
idx, job_id, job_info, job_dict = future.result()
971+
job_infos[idx] = Group(*job_info)
972+
result_dict[job_id] = job_dict
973+
974+
# Remove potential None slots (should not occur)
975+
job_infos = [ji for ji in job_infos if ji is not None]
912976

913977
if return_dict:
914978
return result_dict
@@ -1142,7 +1206,7 @@ def __exit__(self, exc_type, exc_value, tb):
11421206
"Ephemeral logs and artifacts may be lost.",
11431207
)
11441208

1145-
if self._launched:
1209+
if self._launched and not self._skip_status_at_exit:
11461210
self.status()
11471211
return
11481212

@@ -1151,20 +1215,23 @@ def __exit__(self, exc_type, exc_value, tb):
11511215
self.console.rule(
11521216
f"[bold magenta]Direct run Experiment {self._id}",
11531217
)
1154-
self.status()
1218+
if not self._skip_status_at_exit:
1219+
self.status()
11551220
return
11561221

11571222
if hasattr(self, "_waited") and self._waited:
11581223
self.console.rule(
11591224
f"[bold magenta]Done waiting for Experiment {self._id}",
11601225
)
1161-
self.status()
1226+
if not self._skip_status_at_exit:
1227+
self.status()
11621228
return
11631229

11641230
self.console.rule(
11651231
f"[bold magenta]Waiting for Experiment {self._id} to finish",
11661232
)
1167-
self.status()
1233+
if not self._skip_status_at_exit:
1234+
self.status()
11681235

11691236
self._wait_for_jobs(jobs=self.jobs)
11701237
finally:

nemo_run/run/job.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass, field
44
from typing import Optional, Union, cast
55

6-
from torchx.specs.api import AppDef, AppState, is_terminal
6+
from torchx.specs.api import AppDef, AppDryRunInfo, AppState, is_terminal
77

88
import nemo_run.exceptions
99
from nemo_run.config import Config, ConfigurableMixin, Partial, Script
@@ -62,6 +62,10 @@ class Job(ConfigurableMixin):
6262
plugins: Optional[list[ExperimentPlugin]] = None
6363
tail_logs: bool = False
6464
dependencies: list[str] = field(default_factory=list)
65+
name: str = ""
66+
67+
def __post_init__(self):
68+
self._dryrun_info: Optional[AppDryRunInfo] = None
6569

6670
def serialize(self) -> tuple[str, str]:
6771
cfg = self.to_config()
@@ -92,10 +96,14 @@ def logs(self, runner: Runner, regex: str | None = None):
9296
regex=regex,
9397
)
9498

95-
def prepare(self):
99+
def prepare(self, serialize_metadata_for_scripts: bool = True):
96100
self.executor.create_job_dir()
97101
self._executable = package(
98-
self.id, self.task, executor=self.executor, serialize_to_file=True
102+
self.id,
103+
self.task,
104+
executor=self.executor,
105+
serialize_to_file=True,
106+
serialize_metadata_for_scripts=serialize_metadata_for_scripts,
99107
)
100108

101109
def launch(
@@ -120,7 +128,7 @@ def launch(
120128
return
121129

122130
if dryrun:
123-
launch(
131+
_, dryrun_info = launch(
124132
executable=self._executable,
125133
executor_name=executor_str,
126134
executor=self.executor,
@@ -130,6 +138,7 @@ def launch(
130138
log=self.tail_logs,
131139
runner=runner,
132140
)
141+
self._dryrun_info = dryrun_info
133142
return
134143

135144
self.handle, status = launch(
@@ -140,6 +149,7 @@ def launch(
140149
wait=wait,
141150
log=self.tail_logs,
142151
runner=runner,
152+
dryrun_info=self._dryrun_info,
143153
)
144154
self.state = status.state if status else AppState.UNKNOWN
145155
self.launched = True
@@ -223,6 +233,7 @@ class JobGroup(ConfigurableMixin):
223233
plugins: Optional[list[ExperimentPlugin]] = None
224234
tail_logs: bool = False
225235
dependencies: list[str] = field(default_factory=list)
236+
name: str = ""
226237

227238
def __post_init__(self):
228239
executors = [self.executors] if isinstance(self.executors, Executor) else self.executors
@@ -252,6 +263,8 @@ def __post_init__(self):
252263
if len(executors) == 1:
253264
self.executors = executors * len(self.tasks)
254265

266+
self._dryrun_info: Optional[AppDryRunInfo] = None
267+
255268
@property
256269
def state(self) -> AppState:
257270
if not self.launched or not self.handles:
@@ -307,7 +320,7 @@ def logs(self, runner: Runner, regex: str | None = None):
307320
regex=regex,
308321
)
309322

310-
def prepare(self):
323+
def prepare(self, serialize_metadata_for_scripts: bool = True):
311324
self.executor.create_job_dir()
312325
self._executables: list[tuple[AppDef, Executor]] = []
313326
for i, task in enumerate(self.tasks):
@@ -318,6 +331,7 @@ def prepare(self):
318331
task,
319332
executor=executor,
320333
serialize_to_file=True,
334+
serialize_metadata_for_scripts=serialize_metadata_for_scripts,
321335
)
322336
self._executables.append((executable, executor))
323337

@@ -346,7 +360,7 @@ def launch(
346360
executor_str = get_executor_str(executor)
347361

348362
if dryrun:
349-
launch(
363+
_, dryrun_info = launch(
350364
executable=executable,
351365
executor_name=executor_str,
352366
executor=executor,
@@ -356,6 +370,7 @@ def launch(
356370
log=self.tail_logs,
357371
runner=runner,
358372
)
373+
self._dryrun_info = dryrun_info
359374
else:
360375
handle, status = launch(
361376
executable=executable,
@@ -365,6 +380,7 @@ def launch(
365380
wait=wait,
366381
log=self.tail_logs,
367382
runner=runner,
383+
dryrun_info=self._dryrun_info,
368384
)
369385
self.handles.append(handle)
370386
self.states.append(status.state if status else AppState.UNKNOWN)

nemo_run/run/torchx_backend/launcher.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def launch(
4545
parent_run_id: Optional[str] = None,
4646
runner: Runner | None = None,
4747
log_dryrun: bool = ...,
48+
dryrun_info: specs.AppDryRunInfo | None = None,
4849
) -> tuple[None, None]: ...
4950

5051

@@ -59,6 +60,7 @@ def launch(
5960
parent_run_id: Optional[str] = None,
6061
runner: Runner | None = None,
6162
log_dryrun: bool = ...,
63+
dryrun_info: specs.AppDryRunInfo | None = None,
6264
) -> tuple[str, specs.AppStatus]: ...
6365

6466

@@ -73,6 +75,7 @@ def launch(
7375
parent_run_id: Optional[str] = None,
7476
runner: Runner | None = None,
7577
log_dryrun: bool = False,
78+
dryrun_info: specs.AppDryRunInfo | None = None,
7679
) -> tuple[str | None, specs.AppStatus | None]: ...
7780

7881

@@ -86,7 +89,8 @@ def launch(
8689
parent_run_id: Optional[str] = None,
8790
runner: Runner | None = None,
8891
log_dryrun: bool = False,
89-
) -> tuple[str | None, specs.AppStatus | None]:
92+
dryrun_info: specs.AppDryRunInfo | None = None,
93+
) -> tuple[str | None, specs.AppStatus | specs.AppDryRunInfo | None]:
9094
runner = runner or get_runner()
9195

9296
if dryrun:
@@ -100,13 +104,14 @@ def launch(
100104
CONSOLE.log("\n=== APPLICATION ===\n")
101105
CONSOLE.log(dryrun_info)
102106

103-
return None, None
107+
return None, dryrun_info
104108
else:
105109
app_handle = runner.run(
106110
executable,
107111
executor_name,
108112
cfg=executor, # type: ignore
109113
parent_run_id=parent_run_id,
114+
dryrun_info=dryrun_info,
110115
)
111116
logger.info(f"Launched app: {app_handle}")
112117
app_status = specs.AppStatus(state=specs.AppState.SUBMITTED)

0 commit comments

Comments (0)