1616
1717
1818def setup_logger (
19- output_file : Optional [str ] = None ,
19+ output_dir : str ,
20+ filename : Optional [str ] = None ,
21+ distributedlog_filename : Optional [str ] = None ,
2022 logging_config : Optional [Dict ] = None ,
21- output_dir : Optional [str ] = None ,
2223) -> None :
2324 # logging_config must be a dictionary object specifying the configuration
2425 # for the loggers to be used in auto-sklearn.
25- if logging_config is not None :
26- if output_file is not None :
27- logging_config ['handlers' ]['file_handler' ]['filename' ] = output_file
28- if output_dir is not None :
29- logging_config ['handlers' ]['distributed_logfile' ]['filename' ] = os .path .join (
30- output_dir , 'distributed.log'
31- )
32- logging .config .dictConfig (logging_config )
33- else :
26+ if logging_config is None :
3427 with open (os .path .join (os .path .dirname (__file__ ), 'logging.yaml' ), 'r' ) as fh :
3528 logging_config = yaml .safe_load (fh )
36- if output_file is not None :
37- logging_config ['handlers' ]['file_handler' ]['filename' ] = output_file
38- if output_dir is not None :
39- logging_config ['handlers' ]['distributed_logfile' ]['filename' ] = os .path .join (
40- output_dir , 'distributed.log'
41- )
42- logging .config .dictConfig (logging_config )
29+
30+ if filename is None :
31+ filename = logging_config ['handlers' ]['file_handler' ]['filename' ]
32+ logging_config ['handlers' ]['file_handler' ]['filename' ] = os .path .join (
33+ output_dir , filename
34+ )
35+
36+ if distributedlog_filename is None :
37+ distributedlog_filename = logging_config ['handlers' ]['distributed_logfile' ]['filename' ]
38+ logging_config ['handlers' ]['distributed_logfile' ]['filename' ] = os .path .join (
39+ output_dir , distributedlog_filename
40+ )
41+ logging .config .dictConfig (logging_config )
4342
4443
4544def _create_logger (name : str ) -> logging .Logger :
@@ -107,15 +106,22 @@ def isEnabledFor(self, level: int) -> bool:
107106
108107
109108def get_named_client_logger (
109+ output_dir : str ,
110110 name : str ,
111111 host : str = 'localhost' ,
112112 port : int = logging .handlers .DEFAULT_TCP_LOGGING_PORT ,
113113) -> 'PicklableClientLogger' :
114- logger = PicklableClientLogger (name , host , port )
114+ logger = PicklableClientLogger (
115+ output_dir = output_dir ,
116+ name = name ,
117+ host = host ,
118+ port = port
119+ )
115120 return logger
116121
117122
118123def _get_named_client_logger (
124+ output_dir : str ,
119125 name : str ,
120126 host : str = 'localhost' ,
121127 port : int = logging .handlers .DEFAULT_TCP_LOGGING_PORT ,
@@ -133,6 +139,8 @@ def _get_named_client_logger(
133139
134140 Parameters
135141 ----------
142+ outputdir: (str)
143+ The path where the log files are going to be dumped
136144 name: (str)
137145 the name of the logger, used to tag the messages in the main log
138146 host: (str)
@@ -143,7 +151,7 @@ def _get_named_client_logger(
143151 local_loger: a logger object that has a socket handler
144152 """
145153 # Setup the logger configuration
146- setup_logger ()
154+ setup_logger (output_dir = output_dir )
147155
148156 local_logger = _create_logger (name )
149157
@@ -159,11 +167,17 @@ def _get_named_client_logger(
159167
160168class PicklableClientLogger (PickableLoggerAdapter ):
161169
162- def __init__ (self , name : str , host : str , port : int ):
170+ def __init__ (self , output_dir : str , name : str , host : str , port : int ):
171+ self .output_dir = output_dir
163172 self .name = name
164173 self .host = host
165174 self .port = port
166- self .logger = _get_named_client_logger (name , host , port )
175+ self .logger = _get_named_client_logger (
176+ output_dir = output_dir ,
177+ name = name ,
178+ host = host ,
179+ port = port
180+ )
167181
168182 def __getstate__ (self ) -> Dict [str , Any ]:
169183 """
@@ -174,7 +188,12 @@ def __getstate__(self) -> Dict[str, Any]:
174188 Dictionary, representing the object state to be pickled. Ignores
175189 the self.logger field and only returns the logger name.
176190 """
177- return {'name' : self .name , 'host' : self .host , 'port' : self .port }
191+ return {
192+ 'name' : self .name ,
193+ 'host' : self .host ,
194+ 'port' : self .port ,
195+ 'output_dir' : self .output_dir ,
196+ }
178197
179198 def __setstate__ (self , state : Dict [str , Any ]) -> None :
180199 """
@@ -189,7 +208,13 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
189208 self .name = state ['name' ]
190209 self .host = state ['host' ]
191210 self .port = state ['port' ]
192- self .logger = _get_named_client_logger (self .name , self .host , self .port )
211+ self .output_dir = state ['output_dir' ]
212+ self .logger = _get_named_client_logger (
213+ name = self .name ,
214+ host = self .host ,
215+ port = self .port ,
216+ output_dir = self .output_dir ,
217+ )
193218
194219
195220class LogRecordStreamHandler (socketserver .StreamRequestHandler ):
@@ -242,11 +267,13 @@ def start_log_server(
242267 logname : str ,
243268 event : threading .Event ,
244269 port : multiprocessing .Value ,
245- output_file : str ,
270+ filename : str ,
246271 logging_config : Dict ,
247272 output_dir : str ,
248273) -> None :
249- setup_logger (output_file , logging_config , output_dir )
274+ setup_logger (filename = filename ,
275+ logging_config = logging_config ,
276+ output_dir = output_dir )
250277
251278 while True :
252279 # Loop until we find a valid port
0 commit comments