Commit d8a94eb

done testing
1 parent 79343ef commit d8a94eb

File tree

5 files changed: +172 -57 lines changed


dlrover/python/common/constants.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@

 class BasicClass(object):
     LOG_LEVEL_ENV = "DLROVER_LOG_LEVEL"
+    LOG_ROOT_DIR_ENV = "DLROVER_LOG_ROOT_DIR"
+    LOG_AGENT_DIR_ENV = "DLROVER_LOG_AGENT_DIR"


 class PriorityClass(object):
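The two new environment variables are consumed by dlrover/python/common/log.py below. As a minimal sketch (the directory path is an illustrative assumption, not part of this commit), a job can opt into file logging before requesting a logger:

    import os

    from dlrover.python.common.constants import BasicClass

    # Point DLRover's base log at a writable directory; log.py then
    # creates <dir>/dlrover.log behind a RotatingFileHandler.
    os.environ[BasicClass.LOG_ROOT_DIR_ENV] = "/tmp/dlrover_logs"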

dlrover/python/common/log.py

Lines changed: 45 additions & 3 deletions
@@ -15,6 +15,7 @@
 import os
 import sys
 import typing
+from logging.handlers import RotatingFileHandler

 from dlrover.python.common.constants import BasicClass

@@ -31,7 +32,7 @@
 _ch = logging.StreamHandler(stream=sys.stderr)
 _ch.setFormatter(_DEFAULT_FORMATTER)

-_DEFAULT_HANDLERS = [_ch]
+_DEFAULT_HANDLERS: typing.List[logging.Handler] = [_ch]

 _LOGGER_CACHE: typing.Dict[str, logging.Logger] = {}

@@ -43,14 +44,55 @@ def get_log_level():
     return log_level


-def get_logger(name, handlers=None, update=False):
+def get_base_log_dir():
+    log_dir = os.getenv(BasicClass.LOG_ROOT_DIR_ENV, "")
+    return log_dir
+
+
+def get_agent_log_dir():
+    log_dir = os.getenv(BasicClass.LOG_AGENT_DIR_ENV, "")
+    return log_dir
+
+
+def get_base_log_file():
+    log_dir = get_base_log_dir()
+    log_file = ""
+    if log_dir:
+        log_file = os.path.join(log_dir, "dlrover.log")
+        os.makedirs(log_dir, exist_ok=True)
+    return log_file
+
+
+def get_logger(
+    name,
+    handlers: typing.Optional[typing.List[logging.Handler]] = None,
+    update=False,
+):
     __setup_extra_logger()

     if name in _LOGGER_CACHE and not update:
         return _LOGGER_CACHE[name]
     logger = logging.getLogger(name)
     logger.setLevel(get_log_level())
-    logger.handlers = handlers or _DEFAULT_HANDLERS
+
+    if handlers is None:
+        base_log_file = get_base_log_file()
+        if base_log_file:
+            file_handler = RotatingFileHandler(
+                base_log_file,
+                maxBytes=200 * 1024 * 1024,  # 200MB
+                backupCount=3,
+            )
+            file_handler.setFormatter(_DEFAULT_FORMATTER)
+            handlers = [file_handler] + _DEFAULT_HANDLERS
+        else:
+            handlers = _DEFAULT_HANDLERS
+    elif len(handlers) == 0:
+        handlers = _DEFAULT_HANDLERS
+    else:
+        handlers.extend(_DEFAULT_HANDLERS)
+
+    logger.handlers = list(handlers)
     logger.propagate = False
     return logger
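A usage sketch of the new handler selection, assuming the root-dir variable from constants.py is set; the logger name and paths are illustrative:

    import os

    from dlrover.python.common.constants import BasicClass
    from dlrover.python.common.log import get_logger

    os.environ[BasicClass.LOG_ROOT_DIR_ENV] = "/tmp/dlrover_logs"

    # With handlers=None and the env var set, a RotatingFileHandler
    # (200MB x 3 backups) is prepended to the default stderr handler.
    log = get_logger("example.module", update=True)
    log.info("written to stderr and /tmp/dlrover_logs/dlrover.log")

Note the behavioral change: a non-empty explicit handlers list is now extended with the defaults instead of replacing them, so stderr output is always preserved.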
5698

dlrover/python/elastic_agent/torch/training.py

Lines changed: 115 additions & 16 deletions
@@ -19,7 +19,6 @@
 import socket
 import subprocess
 import sys
-import tempfile
 import time
 import traceback
 import uuid
@@ -180,6 +179,79 @@ class StopWorkerTimeoutError(RuntimeError):
     pass


+class LogConfig:
+    _log_dir: Optional[str] = None
+    _redirects: Union[Std, Dict[int, Std]] = Std.NONE
+    _tee: Union[Std, Dict[int, Std]] = Std.NONE
+
+    @classmethod
+    def _parse_std_value(cls, val) -> Std:
+        if val is None:
+            return Std.NONE
+
+        if isinstance(val, str):
+            try:
+                return Std.from_str(val)
+            except ValueError:
+                return Std.NONE
+        elif isinstance(val, Std):
+            return val
+
+        return Std.NONE
+
+    def setup(self, log_dir, redirects=None, tee=None):
+        if not log_dir:
+            return
+
+        self._log_dir = log_dir
+
+        redirects = self._parse_std_value(redirects)
+        tee = self._parse_std_value(tee)
+
+        if redirects == Std.NONE and tee == Std.NONE:
+            # override default when log dir is specified
+            self._redirects = Std.ALL
+            self._tee = Std.ALL
+        else:
+            self._redirects = redirects
+            self._tee = tee
+
+        logger.debug(
+            f"Log config: {self._log_dir}-{self._redirects}-{self._tee}"
+        )
+
+    @property
+    def log_dir(self) -> Optional[str]:
+        return self._log_dir
+
+    @property
+    def redirects(self) -> Union[Std, Dict[int, Std]]:
+        return self._redirects
+
+    @property
+    def tee(self) -> Union[Std, Dict[int, Std]]:
+        return self._tee
+
+    @property
+    def logs_specs(self):
+        if version_less_than_230():
+            return {
+                "log_dir": self.log_dir,
+                "redirects": self.redirects,
+                "tee": self.tee,
+            }
+        else:
+            from torch.distributed.elastic.multiprocessing import (
+                DefaultLogsSpecs,
+            )
+
+            log_specs = DefaultLogsSpecs(
+                log_dir=self.log_dir, redirects=self.redirects, tee=self.tee
+            )
+            log_specs._run_log_dir = self.log_dir
+            return log_specs
+
+
 @dataclass
 class ElasticLaunchConfig(LaunchConfig):
     """
@@ -216,15 +288,33 @@ class ElasticLaunchConfig(LaunchConfig):
     exclude_straggler: bool = False
     save_at_breakpoint: bool = False
     accelerator: str = ""
-    log_dir: Optional[str] = None  # Keep Compatibility with PyTorch>=2.3.0
-    redirects: Union[Std, Dict[int, Std]] = Std.NONE
-    tee: Union[Std, Dict[int, Std]] = Std.NONE
+    log_config: LogConfig = LogConfig()
     training_log_file: str = ""
     failure_node_errors: str = ""
     numa_affinity: bool = False
     membind_policy: str = "none"
     ucp_device_type: str = "cpu"

+    def get_log_dir(self):
+        return self.log_config.log_dir
+
+    def get_log_tee(self):
+        return self.log_config.tee
+
+    def get_log_redirects(self):
+        return self.log_config.redirects
+
+    def get_log_specs(self):
+        return self.log_config.logs_specs
+
+    def setup_log(self, log_dir, redirects=None, tee=None):
+        if log_dir:
+            logger.info(f"Initiate specified log directory: {log_dir}.")
+
+            self.log_config.setup(log_dir, redirects=redirects, tee=tee)
+        else:
+            logger.info("No specified log directory is configured.")
+
     def set_node_unit(self, node_unit):
         """Set the number unit of nodes."""
         self.node_unit = node_unit
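A short sketch of the defaulting rule in LogConfig.setup, which setup_log above delegates to; the directory and Std values are illustrative, and Std is torch's torch.distributed.elastic.multiprocessing.api.Std:

    from torch.distributed.elastic.multiprocessing.api import Std

    from dlrover.python.elastic_agent.torch.training import LogConfig

    cfg = LogConfig()

    # A log dir with no explicit redirects/tee flips both to Std.ALL.
    cfg.setup("/tmp/run_logs")
    assert cfg.redirects == Std.ALL and cfg.tee == Std.ALL

    # Explicit values are parsed (strings go through Std.from_str) and kept.
    cfg.setup("/tmp/run_logs", redirects="3", tee=Std.NONE)
    assert cfg.redirects == Std.ALL and cfg.tee == Std.NONE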
@@ -554,7 +644,6 @@ def __init__(
         spec: WorkerSpec,
         start_method="spawn",
         exit_barrier_timeout: float = 300,
-        log_dir: Optional[str] = None,
         training_log_file: str = "",
         failure_node_errors: str = "",
         with_diagnostician: bool = True,
@@ -564,18 +653,26 @@
                 spec=spec,
                 exit_barrier_timeout=exit_barrier_timeout,
             )
+            # compatible
+            # https://github.com/pytorch/pytorch/blob/39901f229520a5256505ec24782f716ee7ddc843/torch/distributed/elastic/agent/server/local_elastic_agent.py#L148C9-L148C22
+            self._log_dir = config.get_log_dir()
         else:
+            logger.info(
+                "Setup logging configuration for torch version>=230 with "
+                f"log_dir: {config.get_log_dir()}, "
+                f"redirections: {config.get_log_redirects()}, "
+                f"tee: {config.get_log_tee()}, log_specs: {config.get_log_specs().__dict__}"
+            )
             super().__init__(
                 spec=spec,
-                logs_specs=config.logs_specs,
+                logs_specs=config.get_log_specs(),
                 exit_barrier_timeout=exit_barrier_timeout,
             )
         self._node_rank = node_rank
         self._config = config
         self._entrypoint = entrypoint
         self._start_method = start_method
         self._pcontext: Optional[PContext] = None
-        self._log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
         self._worker_watchdog: Optional[timer.FileTimerServer] = None
         self._restart_count = 0
         self._remaining_failovers = self._remaining_restarts
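With the three log fields removed from the config, the agent constructor above pulls everything through the new accessors. A sketch of wiring a config up (constructor arguments follow torch's LaunchConfig and are illustrative):

    from dlrover.python.elastic_agent.torch.training import ElasticLaunchConfig

    config = ElasticLaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=2)

    # Replaces assigning the removed config.log_dir/.redirects/.tee directly.
    config.setup_log("/tmp/run_logs", redirects="3", tee="3")

    assert config.get_log_dir() == "/tmp/run_logs"
    specs = config.get_log_specs()  # dict on torch<2.3.0, DefaultLogsSpecs otherwise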
@@ -1682,7 +1779,7 @@ def launch_agent(
         f"  rdzv_configs      : {config.rdzv_configs}\n"
         f"  max_restarts      : {config.max_restarts}\n"
         f"  monitor_interval  : {config.monitor_interval}\n"
-        f"  log_dir           : {config.log_dir}\n"
+        f"  log_dir           : {config.get_log_dir()}\n"
         f"  metrics_cfg       : {config.metrics_cfg}\n"
         f"  training_log      : {config.training_log_file}\n"
         f"  failure_errors    : {config.failure_node_errors}\n"
@@ -1712,7 +1809,6 @@
         entrypoint=entrypoint,
         spec=spec,
         start_method=config.start_method,
-        log_dir=config.log_dir,
         training_log_file=config.training_log_file,
         failure_node_errors=config.failure_node_errors,
         exit_barrier_timeout=900,
@@ -1826,10 +1922,16 @@ def _create_worker_spec(
         master_addr=master_addr,
     )

+    # for torch < 230, the tee and redirects config for log is located in spec
     if version_less_than_230():
-        spec.redirects = config.redirects
-        spec.tee = config.tee
-
+        spec.redirects = config.get_log_redirects()
+        spec.tee = config.get_log_tee()
+        logger.info(
+            "Setup logging configuration for torch version<230 with "
+            f"log_dir: {config.get_log_dir()}, "
+            f"redirections: {config.get_log_redirects()}, "
+            f"tee: {config.get_log_tee()}"
+        )
     return spec

@@ -1858,7 +1960,6 @@ def __init__(
         spec: WorkerSpec,
         start_method="spawn",
         exit_barrier_timeout: float = 300,
-        log_dir: Optional[str] = None,
         check_round=1,
     ):
         super().__init__(
@@ -1868,10 +1969,8 @@
             spec,
             start_method,
             exit_barrier_timeout,
-            log_dir,
             with_diagnostician=False,
         )
-        self._log_dir = log_dir or tempfile.mkdtemp(prefix="node_check_")
         self._check_round = check_round
         self._config: ElasticLaunchConfig = config
@@ -2063,7 +2162,7 @@ def _create_check_agent(
         f"  rdzv_configs      : {config.rdzv_configs}\n"
         f"  max_restarts      : {config.max_restarts}\n"
         f"  monitor_interval  : {config.monitor_interval}\n"
-        f"  log_dir           : {config.log_dir}\n"
+        f"  log_dir           : {config.get_log_dir()}\n"
         f"  metrics_cfg       : {config.metrics_cfg}\n"
     )
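The torch-version split in one place, as a sketch of what LogConfig.logs_specs hands the agent on each side of torch 2.3.0 (names follow the diff; the directory is illustrative):

    from dlrover.python.elastic_agent.torch.training import (
        LogConfig,
        version_less_than_230,
    )

    log_config = LogConfig()
    log_config.setup("/tmp/run_logs")

    specs = log_config.logs_specs
    if version_less_than_230():
        # torch < 2.3.0: a plain kwargs dict; redirects/tee are copied
        # onto the WorkerSpec in _create_worker_spec.
        assert isinstance(specs, dict)
    else:
        # torch >= 2.3.0: a DefaultLogsSpecs passed to the agent via
        # logs_specs=, with _run_log_dir pinned to the configured dir.
        from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs

        assert isinstance(specs, DefaultLogsSpecs)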
