@@ -310,10 +310,7 @@ def __init__(
310310 self ._runner = get_runner ()
311311
312312 if not _reconstruct :
313- os .makedirs (self ._exp_dir , exist_ok = False )
314-
315313 self .executor = executor if executor else LocalExecutor ()
316- self ._save_config ()
317314 else :
318315 assert isinstance (executor , Executor )
319316 self .executor = executor
@@ -334,6 +331,10 @@ def to_config(self) -> Config:
334331 log_level = self .log_level ,
335332 )
336333
def _save_experiment(self, exist_ok: bool = False):
    """Create the experiment directory and write the experiment config into it.

    Args:
        exist_ok: Passed straight through to ``os.makedirs``. With the
            default ``False``, an already-existing experiment directory
            makes ``os.makedirs`` raise ``FileExistsError``.
    """
    experiment_dir = self._exp_dir
    os.makedirs(experiment_dir, exist_ok=exist_ok)
    # The config file is written inside the directory, so it must exist first.
    self._save_config()
337+
def _save_config(self):
    """Write the serialized result of ``self.to_config()`` to the config file.

    The target path is the class-level ``_CONFIG_FILE`` name joined onto
    the experiment directory (``self._exp_dir``). The directory itself is
    not created here.
    """
    config_path = os.path.join(self._exp_dir, self.__class__._CONFIG_FILE)
    with open(config_path, "w+") as config_file:
        # Serialize only after the file is open, so the file is already
        # truncated by the time serialization runs.
        payload = ZlibJSONSerializer().serialize(self.to_config())
        config_file.write(payload)
@@ -389,6 +390,13 @@ def _load_jobs(self) -> list[Job | JobGroup]:
389390
390391 return jobs
391392
def _prepare(self, exist_ok: bool = False):
    """Write the experiment to disk, prepare every job, then save the jobs.

    Steps, in order:
      1. ``self._save_experiment(exist_ok=exist_ok)``
      2. ``prepare()`` on each entry of ``self.jobs``
      3. ``self._save_jobs()``

    Args:
        exist_ok: Forwarded unchanged to ``_save_experiment``.
    """
    self._save_experiment(exist_ok=exist_ok)

    for pending_job in self.jobs:
        pending_job.prepare()

    # Jobs are saved once, after all of them have been prepared.
    self._save_jobs()
399+
392400 def _add_single_job (
393401 self ,
394402 task : Union [Partial , Script ],
@@ -434,7 +442,6 @@ def _add_single_job(
434442 plugin .assign (self ._id )
435443 plugin .setup (cloned , executor )
436444
437- job .prepare ()
438445 self ._jobs .append (job )
439446 return job .id
440447
@@ -482,7 +489,6 @@ def _add_job_group(
482489 assert isinstance (_executor , Executor )
483490 plugin .setup (task , _executor )
484491
485- job_group .prepare ()
486492 self ._jobs .append (job_group )
487493 return job_group .id
488494
@@ -552,16 +558,17 @@ def add(
552558 dependencies = dependencies .copy () if dependencies else None ,
553559 )
554560
555- self ._save_jobs ()
556561 return job_id
557562
558- def dryrun (self , log : bool = True ):
563+ def dryrun (self , log : bool = True , exist_ok : bool = False , delete_exp_dir : bool = True ):
559564 """
560565 Logs the raw scripts that will be executed for each task.
561566 """
562567 if log :
563568 self .console .log (f"[bold magenta]Experiment { self ._id } dryrun..." )
564569
570+ self ._prepare (exist_ok = exist_ok )
571+
565572 for job in self .jobs :
566573 if isinstance (job , Job ):
567574 if log :
@@ -571,6 +578,9 @@ def dryrun(self, log: bool = True):
571578 self .console .log (f"[bold magenta]Task Group { job .id } \n " )
572579 job .launch (wait = False , runner = self ._runner , dryrun = True , direct = False , log_dryrun = log )
573580
581+ if delete_exp_dir :
582+ shutil .rmtree (self ._exp_dir )
583+
574584 def run (
575585 self ,
576586 sequential : bool = False ,
@@ -614,6 +624,9 @@ def run(
614624 self .console .log ("[bold magenta]Experiment in inspection mode..." )
615625 return
616626
627+ # Prepare experiment before running
628+ self ._prepare ()
629+
617630 if direct :
618631 self .console .log (
619632 "[bold magenta]Running the experiment with direct=True. "
@@ -637,8 +650,8 @@ def run(
637650 os .path .join (job .executor .job_dir , f"log_{ job .id } _direct_run.out" )
638651 ):
639652 job .launch (wait = True , direct = True , runner = self ._runner )
640- self ._save_jobs ()
641653
654+ self ._save_jobs ()
642655 self ._launched = any (map (lambda job : job .launched , self .jobs ))
643656 self ._direct = True
644657 return
@@ -669,7 +682,7 @@ def run(
669682 for i in range (1 , len (self .jobs )):
670683 self .jobs [i ].dependencies .append (self .jobs [i - 1 ].id )
671684
672- self .dryrun (log = False )
685+ self .dryrun (log = False , exist_ok = True , delete_exp_dir = False )
673686 for tunnel in self .tunnels .values ():
674687 if isinstance (tunnel , SSHTunnel ):
675688 tunnel .connect ()
@@ -746,14 +759,14 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
746759 job .executor .dependencies = deps # type: ignore
747760 job .launch (wait = False , runner = self ._runner )
748761
749- self ._save_jobs ()
750762 except Exception as e :
751763 self .console .log (f"Error running job { job .id } : { e } " )
752764 raise e
753765
754766 if wait :
755767 self ._wait_for_jobs (jobs = [job_map [node ] for node in level ])
756768
769+ self ._save_jobs ()
757770 self ._launched = any (map (lambda job : job .launched , self .jobs ))
758771 self ._waited = wait
759772
@@ -955,7 +968,6 @@ def reset(self) -> "Experiment":
955968 old_id , old_exp_dir , old_launched = self ._id , self ._exp_dir , self ._launched
956969 self ._id = f"{ self ._title } _{ int (time .time ())} "
957970 self ._exp_dir = os .path .join (NEMORUN_HOME , "experiments" , self ._title , self ._id )
958- os .makedirs (self ._exp_dir , exist_ok = False )
959971 self ._launched = False
960972 self ._live_progress = None
961973
@@ -967,12 +979,9 @@ def reset(self) -> "Experiment":
967979 _current_experiment .set (self )
968980 _set_current_experiment = True
969981
970- if "__main__.py" in os .listdir (old_exp_dir ):
971- shutil .copy (os .path .join (old_exp_dir , "__main__.py" ), self ._exp_dir )
972-
973982 try :
974983 if "__external_main__" not in sys .modules :
975- maybe_load_external_main (self . _exp_dir )
984+ maybe_load_external_main (old_exp_dir )
976985
977986 for job in jobs :
978987 if isinstance (job , Job ):
@@ -1022,8 +1031,6 @@ def reset(self) -> "Experiment":
10221031 self ._current_experiment_token = None
10231032
10241033 self ._reconstruct = False
1025- self ._save_config ()
1026-
10271034 return self
10281035
10291036 def _initialize_live_progress (self ):
0 commit comments