 import socket
 import subprocess
 import sys
-import tempfile
 import time
 import traceback
 import uuid
@@ -180,6 +179,79 @@ class StopWorkerTimeoutError(RuntimeError):
     pass
 
 
+class LogConfig:
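+    """Worker log settings: the log directory plus stdout/stderr redirect and tee specs."""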
+    _log_dir: Optional[str] = None
+    _redirects: Union[Std, Dict[int, Std]] = Std.NONE
+    _tee: Union[Std, Dict[int, Std]] = Std.NONE
+
+    @classmethod
+    def _parse_std_value(cls, val) -> Std:
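+        """Normalize a value given as None, str, or Std into an Std enum."""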
+        if val is None:
+            return Std.NONE
+
+        if isinstance(val, str):
+            try:
+                return Std.from_str(val)
+            except ValueError:
+                return Std.NONE
+        elif isinstance(val, Std):
+            return val
+
+        return Std.NONE
+
+    def setup(self, log_dir, redirects=None, tee=None):
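+        """Record the log directory and parse the redirect/tee values; no-op when log_dir is empty."""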
+        if not log_dir:
+            return
+
+        self._log_dir = log_dir
+
+        redirects = self._parse_std_value(redirects)
+        tee = self._parse_std_value(tee)
+
+        if redirects == Std.NONE and tee == Std.NONE:
+            # override default when log dir is specified
+            self._redirects = Std.ALL
+            self._tee = Std.ALL
+        else:
+            self._redirects = redirects
+            self._tee = tee
+
+        logger.debug(
+            f"Log config: {self._log_dir}-{self._redirects}-{self._tee}"
+        )
+
+    @property
+    def log_dir(self) -> Optional[str]:
+        return self._log_dir
+
+    @property
+    def redirects(self) -> Union[Std, Dict[int, Std]]:
+        return self._redirects
+
+    @property
+    def tee(self) -> Union[Std, Dict[int, Std]]:
+        return self._tee
+
+    @property
+    def logs_specs(self):
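+        """Return the log settings in the form expected by the installed torch version."""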
+        if version_less_than_230():
+            return {
+                "log_dir": self.log_dir,
+                "redirects": self.redirects,
+                "tee": self.tee,
+            }
+        else:
+            from torch.distributed.elastic.multiprocessing import (
+                DefaultLogsSpecs,
+            )
+
+            log_specs = DefaultLogsSpecs(
+                log_dir=self.log_dir, redirects=self.redirects, tee=self.tee
+            )
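+            # Pre-set the run log dir so worker logs stay directly under the
+            # configured directory instead of a per-run subdirectory.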
+            log_specs._run_log_dir = self.log_dir
+            return log_specs
+
+
 @dataclass
 class ElasticLaunchConfig(LaunchConfig):
     """
@@ -216,15 +288,33 @@ class ElasticLaunchConfig(LaunchConfig):
     exclude_straggler: bool = False
     save_at_breakpoint: bool = False
     accelerator: str = ""
-    log_dir: Optional[str] = None  # Keep Compatibility with PyTorch>=2.3.0
-    redirects: Union[Std, Dict[int, Std]] = Std.NONE
-    tee: Union[Std, Dict[int, Std]] = Std.NONE
+    log_config: LogConfig = LogConfig()
     training_log_file: str = ""
     failure_node_errors: str = ""
     numa_affinity: bool = False
     membind_policy: str = "none"
     ucp_device_type: str = "cpu"
 
+    def get_log_dir(self):
+        return self.log_config.log_dir
+
+    def get_log_tee(self):
+        return self.log_config.tee
+
+    def get_log_redirects(self):
+        return self.log_config.redirects
+
+    def get_log_specs(self):
+        return self.log_config.logs_specs
+
+    def setup_log(self, log_dir, redirects=None, tee=None):
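+        """Configure the worker log directory and redirect/tee settings when log_dir is given."""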
+        if log_dir:
+            logger.info(f"Initialize the specified log directory: {log_dir}.")
+            self.log_config.setup(log_dir, redirects=redirects, tee=tee)
+        else:
+            logger.info("No log directory is configured.")
+
     def set_node_unit(self, node_unit):
         """Set the number unit of nodes."""
         self.node_unit = node_unit
@@ -554,7 +644,6 @@ def __init__(
         spec: WorkerSpec,
         start_method="spawn",
         exit_barrier_timeout: float = 300,
-        log_dir: Optional[str] = None,
         training_log_file: str = "",
         failure_node_errors: str = "",
         with_diagnostician: bool = True,
@@ -564,18 +653,26 @@ def __init__(
                 spec=spec,
                 exit_barrier_timeout=exit_barrier_timeout,
             )
+            # Keep self._log_dir for compatibility with how LocalElasticAgent
+            # uses it before torch 2.3.0; see
+            # https://github.com/pytorch/pytorch/blob/39901f229520a5256505ec24782f716ee7ddc843/torch/distributed/elastic/agent/server/local_elastic_agent.py#L148C9-L148C22
+            self._log_dir = config.get_log_dir()
         else:
+            logger.info(
+                "Set up logging configuration for torch>=2.3.0 with "
+                f"log_dir: {config.get_log_dir()}, "
+                f"redirects: {config.get_log_redirects()}, "
+                f"tee: {config.get_log_tee()}, "
+                f"log_specs: {config.get_log_specs().__dict__}"
+            )
             super().__init__(
                 spec=spec,
-                logs_specs=config.logs_specs,
+                logs_specs=config.get_log_specs(),
                 exit_barrier_timeout=exit_barrier_timeout,
             )
         self._node_rank = node_rank
         self._config = config
         self._entrypoint = entrypoint
         self._start_method = start_method
         self._pcontext: Optional[PContext] = None
-        self._log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
         self._worker_watchdog: Optional[timer.FileTimerServer] = None
         self._restart_count = 0
         self._remaining_failovers = self._remaining_restarts
@@ -1682,7 +1779,7 @@ def launch_agent(
         f"  rdzv_configs     : {config.rdzv_configs}\n"
         f"  max_restarts     : {config.max_restarts}\n"
         f"  monitor_interval : {config.monitor_interval}\n"
-        f"  log_dir          : {config.log_dir}\n"
+        f"  log_dir          : {config.get_log_dir()}\n"
         f"  metrics_cfg      : {config.metrics_cfg}\n"
         f"  training_log     : {config.training_log_file}\n"
         f"  failure_errors   : {config.failure_node_errors}\n"
@@ -1712,7 +1809,6 @@ def launch_agent(
         entrypoint=entrypoint,
         spec=spec,
         start_method=config.start_method,
-        log_dir=config.log_dir,
         training_log_file=config.training_log_file,
         failure_node_errors=config.failure_node_errors,
         exit_barrier_timeout=900,
@@ -1826,10 +1922,16 @@ def _create_worker_spec(
         master_addr=master_addr,
     )
 
+    # For torch < 2.3.0, the redirects/tee log config is carried on the worker spec.
     if version_less_than_230():
-        spec.redirects = config.redirects
-        spec.tee = config.tee
-
+        spec.redirects = config.get_log_redirects()
+        spec.tee = config.get_log_tee()
+        logger.info(
+            "Set up logging configuration for torch<2.3.0 with "
+            f"log_dir: {config.get_log_dir()}, "
+            f"redirects: {config.get_log_redirects()}, "
+            f"tee: {config.get_log_tee()}"
+        )
     return spec
 
 
@@ -1858,7 +1960,6 @@ def __init__(
         spec: WorkerSpec,
         start_method="spawn",
         exit_barrier_timeout: float = 300,
-        log_dir: Optional[str] = None,
         check_round=1,
     ):
         super().__init__(
@@ -1868,10 +1969,8 @@ def __init__(
             spec,
             start_method,
             exit_barrier_timeout,
-            log_dir,
             with_diagnostician=False,
         )
-        self._log_dir = log_dir or tempfile.mkdtemp(prefix="node_check_")
         self._check_round = check_round
         self._config: ElasticLaunchConfig = config
 
@@ -2063,7 +2162,7 @@ def _create_check_agent(
         f"  rdzv_configs     : {config.rdzv_configs}\n"
         f"  max_restarts     : {config.max_restarts}\n"
         f"  monitor_interval : {config.monitor_interval}\n"
-        f"  log_dir          : {config.log_dir}\n"
+        f"  log_dir          : {config.get_log_dir()}\n"
         f"  metrics_cfg      : {config.metrics_cfg}\n"
     )
 