3535from rich .console import Group
3636from rich .live import Live
3737from rich .panel import Panel
38- from rich .progress import BarColumn , Progress , SpinnerColumn
38+ from rich .progress import BarColumn , Progress , SpinnerColumn , TaskID , TimeElapsedColumn
3939from rich .progress import Task as RichTask
40- from rich .progress import TaskID , TimeElapsedColumn
4140from rich .syntax import Syntax
4241from torchx .specs .api import AppState
4342
@@ -303,6 +302,9 @@ def __init__(
303302 base_dir : str | None = None ,
304303 clean_mode : bool = False ,
305304 enable_goodbye_message : bool = True ,
305+ threadpool_workers : int = 16 ,
306+ skip_status_at_exit : bool = False ,
307+ serialize_metadata_for_scripts : bool = True ,
306308 ) -> None :
307309 """
308310 Initializes an experiment run by creating its metadata directory and saving the experiment config.
@@ -328,6 +330,9 @@ def __init__(
328330 self ._title = title
329331 self ._id = id or f"{ title } _{ int (time .time ())} "
330332 self ._enable_goodbye_message = enable_goodbye_message
333+ self ._threadpool_workers = threadpool_workers
334+ self ._skip_status_at_exit = skip_status_at_exit
335+ self ._serialize_metadata_for_scripts = serialize_metadata_for_scripts
331336
332337 base_dir = str (base_dir or get_nemorun_home ())
333338 self ._exp_dir = os .path .join (base_dir , "experiments" , title , self ._id )
@@ -359,6 +364,10 @@ def to_config(self) -> Config:
359364 executor = self .executor .to_config (),
360365 log_level = self .log_level ,
361366 clean_mode = self .clean_mode ,
367+ threadpool_workers = self ._threadpool_workers ,
368+ enable_goodbye_message = self ._enable_goodbye_message ,
369+ skip_status_at_exit = self ._skip_status_at_exit ,
370+ serialize_metadata_for_scripts = self ._serialize_metadata_for_scripts ,
362371 )
363372
364373 def _save_experiment (self , exist_ok : bool = False ):
@@ -422,8 +431,9 @@ def _load_jobs(self) -> list[Job | JobGroup]:
422431
def _prepare(self, exist_ok: bool = False):
    """Persist experiment metadata and ready every job for launch.

    Saves the experiment config first (so a crash during job prep still
    leaves a record on disk), prepares each job — forwarding the
    experiment-level metadata-serialization preference — and finally
    snapshots the job list.

    Args:
        exist_ok: Passed through to ``_save_experiment``; when True, an
            already-existing experiment directory is not an error.
    """
    self._save_experiment(exist_ok=exist_ok)

    serialize = self._serialize_metadata_for_scripts
    for current_job in self.jobs:
        current_job.prepare(serialize_metadata_for_scripts=serialize)

    self._save_jobs()
429439
@@ -769,7 +779,15 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
769779 self .detach = detach
770780
771781 for level in order :
772- for _ , node in enumerate (level ):
782+ # Launch jobs in this level concurrently since they are independent
783+
784+ def _set_context (ctx : contextvars .Context ):
785+ for var , value in ctx .items ():
786+ var .set (value )
787+
788+ ctx = contextvars .copy_context ()
789+
790+ def _launch (node : str ):
773791 job : Job | JobGroup = job_map [node ]
774792 self .console .log (f"[bold cyan]Launching job { job .id } for experiment { self ._title } " )
775793 if tail_logs :
@@ -787,14 +805,24 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
787805 deps .append (handle )
788806
789807 job .executor .dependencies = deps # type: ignore
808+
790809 job .launch (wait = False , runner = self ._runner )
810+ return job
791811
792812 except Exception as e :
793813 self .console .log (f"Error running job { job .id } : { e } " )
794814 raise e
795815
816+ launched_jobs : list [Job | JobGroup ] = []
817+ with ThreadPoolExecutor (
818+ initializer = _set_context , initargs = (ctx ,), max_workers = self ._threadpool_workers
819+ ) as pool :
820+ futures = [pool .submit (_launch , node ) for node in level ]
821+ for future in as_completed (futures ):
822+ launched_jobs .append (future .result ())
823+
796824 if wait :
797- self ._wait_for_jobs (jobs = [ job_map [ node ] for node in level ] )
825+ self ._wait_for_jobs (jobs = launched_jobs )
798826
799827 self ._save_jobs ()
800828 self ._launched = any (map (lambda job : job .launched , self .jobs ))
@@ -840,7 +868,21 @@ def set_context(context: contextvars.Context):
840868 finally :
841869 job .cleanup ()
842870
843- def status (self , return_dict : bool = False ) -> Optional [dict [str , str ]]:
def _initialize_tunnels(self, extract_from_executors: bool = False):
    """Ensure every SSH tunnel used by the experiment is connected.

    Args:
        extract_from_executors: When True, first harvest tunnels from any
            ``SlurmExecutor``-backed jobs whose tunnel key is not already
            registered in ``self.tunnels``.

    Raises:
        AssertionError: If an ``SSHTunnel`` fails to establish a session.
    """
    if extract_from_executors:
        for job in self.jobs:
            executor = job.executor
            if not isinstance(executor, SlurmExecutor):
                continue
            key = executor.tunnel.key
            if key not in self.tunnels:
                self.tunnels[key] = executor.tunnel

    for tunnel in self.tunnels.values():
        if not isinstance(tunnel, SSHTunnel):
            continue
        tunnel.connect()
        assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
885+ def status (self , return_dict : bool = False ) -> Optional [dict [str , dict [str , str ]]]:
844886 """
845887 Prints a table specifying the status of all tasks.
846888
@@ -880,6 +922,7 @@ def _get_job_info_and_dict(
880922 "status" : job .status (runner = self ._runner ),
881923 "executor" : job .executor .info (),
882924 "job_id" : app_id ,
925+ "handle" : job .handle ,
883926 "local_dir" : job .executor .job_dir ,
884927 }
885928
@@ -902,13 +945,34 @@ def _get_job_info_and_dict(
902945 job_info .extend (directory_info )
903946 return job_info , job_dict
904947
948+ self ._initialize_tunnels (extract_from_executors = True )
905949 try :
906950 result_dict = {}
907- job_infos = []
908- for i , job in enumerate (self .jobs ):
909- job_info , job_dict = _get_job_info_and_dict (i , job )
910- job_infos .append (Group (* job_info ))
911- result_dict [job .id ] = job_dict
951+ job_infos : list [Group | None ] = [None ] * len (self .jobs )
952+
953+ # Parallelize IO-bound status retrieval across jobs
954+ def _collect (arg ):
955+ idx , job = arg
956+ job_info , job_dict = _get_job_info_and_dict (idx , job )
957+ return idx , job .id , job_info , job_dict
958+
959+ # Propagate context variables to worker threads so helpers that rely on them keep working
960+ def _set_context (ctx : contextvars .Context ):
961+ for var , value in ctx .items ():
962+ var .set (value )
963+
964+ ctx = contextvars .copy_context ()
965+ with ThreadPoolExecutor (
966+ initializer = _set_context , initargs = (ctx ,), max_workers = self ._threadpool_workers
967+ ) as pool :
968+ futures = [pool .submit (_collect , (idx , job )) for idx , job in enumerate (self .jobs )]
969+ for future in as_completed (futures ):
970+ idx , job_id , job_info , job_dict = future .result ()
971+ job_infos [idx ] = Group (* job_info )
972+ result_dict [job_id ] = job_dict
973+
974+ # Remove potential None slots (should not occur)
975+ job_infos = [ji for ji in job_infos if ji is not None ]
912976
913977 if return_dict :
914978 return result_dict
@@ -1142,7 +1206,7 @@ def __exit__(self, exc_type, exc_value, tb):
11421206 "Ephemeral logs and artifacts may be lost." ,
11431207 )
11441208
1145- if self ._launched :
1209+ if self ._launched and not self . _skip_status_at_exit :
11461210 self .status ()
11471211 return
11481212
@@ -1151,20 +1215,23 @@ def __exit__(self, exc_type, exc_value, tb):
11511215 self .console .rule (
11521216 f"[bold magenta]Direct run Experiment { self ._id } " ,
11531217 )
1154- self .status ()
1218+ if not self ._skip_status_at_exit :
1219+ self .status ()
11551220 return
11561221
11571222 if hasattr (self , "_waited" ) and self ._waited :
11581223 self .console .rule (
11591224 f"[bold magenta]Done waiting for Experiment { self ._id } " ,
11601225 )
1161- self .status ()
1226+ if not self ._skip_status_at_exit :
1227+ self .status ()
11621228 return
11631229
11641230 self .console .rule (
11651231 f"[bold magenta]Waiting for Experiment { self ._id } to finish" ,
11661232 )
1167- self .status ()
1233+ if not self ._skip_status_at_exit :
1234+ self .status ()
11681235
11691236 self ._wait_for_jobs (jobs = self .jobs )
11701237 finally :
0 commit comments