Skip to content

Commit 03db88e

Browse files
authored
Minimize io operations in experiment, avoid persisting experiment to disk before run/dryrun (#150)
* Minimize io operations in experiment, avoid persisting experiment to disk before run/dryrun Signed-off-by: Hemil Desai <[email protected]> * delete directory at end of dryrun Signed-off-by: Hemil Desai <[email protected]> * Fixes Signed-off-by: Hemil Desai <[email protected]> --------- Signed-off-by: Hemil Desai <[email protected]>
1 parent d1ae06f commit 03db88e

File tree

9 files changed

+30
-24
lines changed

9 files changed

+30
-24
lines changed

src/nemo_run/core/execution/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
227227
filenames.append(filename)
228228
return filenames
229229

230+
def create_job_dir(self):
231+
os.makedirs(self.job_dir, exist_ok=True)
232+
230233
def cleanup(self, handle: str): ...
231234

232235

src/nemo_run/core/execution/dgxcloud.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def assign(
226226
self.experiment_dir = exp_dir
227227
self.job_dir = os.path.join(exp_dir, task_dir)
228228
self.experiment_id = exp_id
229-
os.makedirs(self.job_dir, exist_ok=True)
230229
assert any(
231230
map(
232231
lambda x: os.path.commonpath(

src/nemo_run/core/execution/docker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ def assign(
162162
self.experiment_id = exp_id
163163
self.experiment_dir = exp_dir
164164
self.job_dir = os.path.join(exp_dir, task_dir)
165-
os.makedirs(self.job_dir, exist_ok=True)
166165

167166
def nnodes(self) -> int:
168167
return 1

src/nemo_run/core/execution/local.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def assign(
4848
self.experiment_id = exp_id
4949
self.experiment_dir = exp_dir
5050
self.job_dir = os.path.join(exp_dir, task_dir)
51-
os.makedirs(self.job_dir, exist_ok=True)
5251

5352
def nnodes(self) -> int:
5453
return 1

src/nemo_run/core/execution/skypilot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,6 @@ def assign(
280280
self.job_dir = os.path.join(exp_dir, task_dir)
281281
self.experiment_id = exp_id
282282

283-
os.makedirs(self.job_dir, exist_ok=True)
284-
285283
def package(self, packager: Packager, job_name: str):
286284
assert self.experiment_id, "Executor not assigned to an experiment."
287285
if isinstance(packager, GitArchivePackager):

src/nemo_run/core/execution/slurm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,6 @@ def assign(
514514
self.job_dir = os.path.join(exp_dir, task_dir)
515515
self.experiment_id = exp_id
516516

517-
os.makedirs(self.job_dir, exist_ok=True)
518517
self.tunnel._set_job_dir(self.experiment_id)
519518

520519
def get_launcher_prefix(self) -> Optional[list[str]]:

src/nemo_run/run/experiment.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,7 @@ def __init__(
310310
self._runner = get_runner()
311311

312312
if not _reconstruct:
313-
os.makedirs(self._exp_dir, exist_ok=False)
314-
315313
self.executor = executor if executor else LocalExecutor()
316-
self._save_config()
317314
else:
318315
assert isinstance(executor, Executor)
319316
self.executor = executor
@@ -334,6 +331,10 @@ def to_config(self) -> Config:
334331
log_level=self.log_level,
335332
)
336333

334+
def _save_experiment(self, exist_ok: bool = False):
335+
os.makedirs(self._exp_dir, exist_ok=exist_ok)
336+
self._save_config()
337+
337338
def _save_config(self):
338339
with open(os.path.join(self._exp_dir, self.__class__._CONFIG_FILE), "w+") as f:
339340
f.write(ZlibJSONSerializer().serialize(self.to_config()))
@@ -389,6 +390,13 @@ def _load_jobs(self) -> list[Job | JobGroup]:
389390

390391
return jobs
391392

393+
def _prepare(self, exist_ok: bool = False):
394+
self._save_experiment(exist_ok=exist_ok)
395+
for job in self.jobs:
396+
job.prepare()
397+
398+
self._save_jobs()
399+
392400
def _add_single_job(
393401
self,
394402
task: Union[Partial, Script],
@@ -434,7 +442,6 @@ def _add_single_job(
434442
plugin.assign(self._id)
435443
plugin.setup(cloned, executor)
436444

437-
job.prepare()
438445
self._jobs.append(job)
439446
return job.id
440447

@@ -482,7 +489,6 @@ def _add_job_group(
482489
assert isinstance(_executor, Executor)
483490
plugin.setup(task, _executor)
484491

485-
job_group.prepare()
486492
self._jobs.append(job_group)
487493
return job_group.id
488494

@@ -552,16 +558,17 @@ def add(
552558
dependencies=dependencies.copy() if dependencies else None,
553559
)
554560

555-
self._save_jobs()
556561
return job_id
557562

558-
def dryrun(self, log: bool = True):
563+
def dryrun(self, log: bool = True, exist_ok: bool = False, delete_exp_dir: bool = True):
559564
"""
560565
Logs the raw scripts that will be executed for each task.
561566
"""
562567
if log:
563568
self.console.log(f"[bold magenta]Experiment {self._id} dryrun...")
564569

570+
self._prepare(exist_ok=exist_ok)
571+
565572
for job in self.jobs:
566573
if isinstance(job, Job):
567574
if log:
@@ -571,6 +578,9 @@ def dryrun(self, log: bool = True):
571578
self.console.log(f"[bold magenta]Task Group {job.id}\n")
572579
job.launch(wait=False, runner=self._runner, dryrun=True, direct=False, log_dryrun=log)
573580

581+
if delete_exp_dir:
582+
shutil.rmtree(self._exp_dir)
583+
574584
def run(
575585
self,
576586
sequential: bool = False,
@@ -614,6 +624,9 @@ def run(
614624
self.console.log("[bold magenta]Experiment in inspection mode...")
615625
return
616626

627+
# Prepare experiment before running
628+
self._prepare()
629+
617630
if direct:
618631
self.console.log(
619632
"[bold magenta]Running the experiment with direct=True. "
@@ -637,8 +650,8 @@ def run(
637650
os.path.join(job.executor.job_dir, f"log_{job.id}_direct_run.out")
638651
):
639652
job.launch(wait=True, direct=True, runner=self._runner)
640-
self._save_jobs()
641653

654+
self._save_jobs()
642655
self._launched = any(map(lambda job: job.launched, self.jobs))
643656
self._direct = True
644657
return
@@ -669,7 +682,7 @@ def run(
669682
for i in range(1, len(self.jobs)):
670683
self.jobs[i].dependencies.append(self.jobs[i - 1].id)
671684

672-
self.dryrun(log=False)
685+
self.dryrun(log=False, exist_ok=True, delete_exp_dir=False)
673686
for tunnel in self.tunnels.values():
674687
if isinstance(tunnel, SSHTunnel):
675688
tunnel.connect()
@@ -746,14 +759,14 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
746759
job.executor.dependencies = deps # type: ignore
747760
job.launch(wait=False, runner=self._runner)
748761

749-
self._save_jobs()
750762
except Exception as e:
751763
self.console.log(f"Error running job {job.id}: {e}")
752764
raise e
753765

754766
if wait:
755767
self._wait_for_jobs(jobs=[job_map[node] for node in level])
756768

769+
self._save_jobs()
757770
self._launched = any(map(lambda job: job.launched, self.jobs))
758771
self._waited = wait
759772

@@ -955,7 +968,6 @@ def reset(self) -> "Experiment":
955968
old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
956969
self._id = f"{self._title}_{int(time.time())}"
957970
self._exp_dir = os.path.join(NEMORUN_HOME, "experiments", self._title, self._id)
958-
os.makedirs(self._exp_dir, exist_ok=False)
959971
self._launched = False
960972
self._live_progress = None
961973

@@ -967,12 +979,9 @@ def reset(self) -> "Experiment":
967979
_current_experiment.set(self)
968980
_set_current_experiment = True
969981

970-
if "__main__.py" in os.listdir(old_exp_dir):
971-
shutil.copy(os.path.join(old_exp_dir, "__main__.py"), self._exp_dir)
972-
973982
try:
974983
if "__external_main__" not in sys.modules:
975-
maybe_load_external_main(self._exp_dir)
984+
maybe_load_external_main(old_exp_dir)
976985

977986
for job in jobs:
978987
if isinstance(job, Job):
@@ -1022,8 +1031,6 @@ def reset(self) -> "Experiment":
10221031
self._current_experiment_token = None
10231032

10241033
self._reconstruct = False
1025-
self._save_config()
1026-
10271034
return self
10281035

10291036
def _initialize_live_progress(self):

src/nemo_run/run/job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def logs(self, runner: Runner, regex: str | None = None):
9292
)
9393

9494
def prepare(self):
95+
self.executor.create_job_dir()
9596
self._executable = package(
9697
self.id, self.task, executor=self.executor, serialize_to_file=True
9798
)
@@ -306,6 +307,7 @@ def logs(self, runner: Runner, regex: str | None = None):
306307
)
307308

308309
def prepare(self):
310+
self.executor.create_job_dir()
309311
self._executables: list[tuple[AppDef, Executor]] = []
310312
for i, task in enumerate(self.tasks):
311313
executor = self.executors if self._merge else self.executors[i] # type: ignore

test/core/execution/test_local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_local_executor_assign():
3737

3838
assert executor.experiment_id == "test_exp"
3939
assert executor.job_dir == os.path.join(tmp_dir, "test_task")
40-
assert os.path.exists(executor.job_dir)
40+
assert not os.path.exists(executor.job_dir)
4141

4242

4343
def test_local_executor_nnodes():

0 commit comments

Comments
 (0)