77from queue import Empty
88import time
99import traceback
10- from typing import Dict , List , Optional , Tuple , Union
10+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union , cast
1111
1212from ConfigSpace import Configuration
1313import numpy as np
1414import pynisher
1515from smac .runhistory .runhistory import RunInfo , RunValue
16+ from smac .stats .stats import Stats
1617from smac .tae import StatusType , TAEAbortException
1718from smac .tae .execute_func import AbstractTAFunc
1819
2324import autosklearn .evaluation .train_evaluator
2425import autosklearn .evaluation .test_evaluator
2526import autosklearn .evaluation .util
26- from autosklearn .util .logging_ import get_named_client_logger
27+ from autosklearn .evaluation .train_evaluator import TYPE_ADDITIONAL_INFO
28+ from autosklearn .util .backend import Backend
29+ from autosklearn .util .logging_ import PickableLoggerAdapter , get_named_client_logger
2730from autosklearn .util .parallel import preload_modules
2831
2932
30- def fit_predict_try_except_decorator (ta , queue , cost_for_crash , ** kwargs ):
33+ def fit_predict_try_except_decorator (
34+ ta : Callable ,
35+ queue : multiprocessing .Queue ,
36+ cost_for_crash : float ,
37+ ** kwargs : Any ) -> None :
3138
3239 try :
3340 return ta (queue = queue , ** kwargs )
@@ -66,7 +73,7 @@ def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
6673 queue .close ()
6774
6875
69- def get_cost_of_crash (metric ) :
76+ def get_cost_of_crash (metric : Scorer ) -> float :
7077
7178 # The metric must always be defined to extract optimum/worst
7279 if not isinstance (metric , Scorer ):
@@ -85,8 +92,11 @@ def get_cost_of_crash(metric):
8592 return worst_possible_result
8693
8794
88- def _encode_exit_status (exit_status ):
95+ def _encode_exit_status (exit_status : Union [str , int , Type [BaseException ]]
96+ ) -> Union [str , int ]:
8997 try :
98+ # If it can be dumped, then it is int
99+ exit_status = cast (int , exit_status )
90100 json .dumps (exit_status )
91101 return exit_status
92102 except (TypeError , OverflowError ):
@@ -97,13 +107,31 @@ def _encode_exit_status(exit_status):
97107# easier debugging of potential crashes
98108class ExecuteTaFuncWithQueue (AbstractTAFunc ):
99109
100- def __init__ (self , backend , autosklearn_seed , resampling_strategy , metric ,
101- cost_for_crash , abort_on_first_run_crash , port , pynisher_context ,
102- initial_num_run = 1 , stats = None ,
103- run_obj = 'quality' , par_factor = 1 , scoring_functions = None ,
104- output_y_hat_optimization = True , include = None , exclude = None ,
105- memory_limit = None , disable_file_output = False , init_params = None ,
106- budget_type = None , ta = False , ** resampling_strategy_args ):
110+ def __init__ (
111+ self ,
112+ backend : Backend ,
113+ autosklearn_seed : int ,
114+ resampling_strategy : Union [str , BaseCrossValidator , _RepeatedSplits , BaseShuffleSplit ],
115+ metric : Scorer ,
116+ cost_for_crash : float ,
117+ abort_on_first_run_crash : bool ,
118+ port : int ,
119+ pynisher_context : str ,
120+ initial_num_run : int = 1 ,
121+ stats : Optional [Stats ] = None ,
122+ run_obj : str = 'quality' ,
123+ par_factor : int = 1 ,
124+ scoring_functions : Optional [List [Scorer ]] = None ,
125+ output_y_hat_optimization : bool = True ,
126+ include : Optional [List [str ]] = None ,
127+ exclude : Optional [List [str ]] = None ,
128+ memory_limit : Optional [int ] = None ,
129+ disable_file_output : bool = False ,
130+ init_params : Optional [Dict [str , Any ]] = None ,
131+ budget_type : Optional [str ] = None ,
132+ ta : Optional [Callable ] = None ,
133+ ** resampling_strategy_args : Any ,
134+ ):
107135
108136 if resampling_strategy == 'holdout' :
109137 eval_function = autosklearn .evaluation .train_evaluator .eval_holdout
@@ -180,7 +208,7 @@ def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
180208 self .port = port
181209 self .pynisher_context = pynisher_context
182210 if self .port is None :
183- self .logger = logging .getLogger ("TAE" )
211+ self .logger : Union [ logging . Logger , PickableLoggerAdapter ] = logging .getLogger ("TAE" )
184212 else :
185213 self .logger = get_named_client_logger (
186214 name = "TAE" ,
@@ -261,6 +289,10 @@ def run(
261289 instance_specific : Optional [str ] = None ,
262290 ) -> Tuple [StatusType , float , float , Dict [str , Union [int , float , str , Dict , List , Tuple ]]]:
263291
292+ # Additional information of each of the tae executions
293+ # Defined upfront for mypy
294+ additional_run_info : TYPE_ADDITIONAL_INFO = {}
295+
264296 context = multiprocessing .get_context (self .pynisher_context )
265297 preload_modules (context )
266298 queue = context .Queue ()
@@ -272,7 +304,7 @@ def run(
272304 init_params .update (self .init_params )
273305
274306 if self .port is None :
275- logger = logging .getLogger ("pynisher" )
307+ logger : Union [ logging . Logger , PickableLoggerAdapter ] = logging .getLogger ("pynisher" )
276308 else :
277309 logger = get_named_client_logger (
278310 name = "pynisher" ,
@@ -320,11 +352,11 @@ def run(
320352 except Exception as e :
321353 exception_traceback = traceback .format_exc ()
322354 error_message = repr (e )
323- additional_info = {
355+ additional_run_info . update ( {
324356 'traceback' : exception_traceback ,
325357 'error' : error_message
326- }
327- return StatusType .CRASHED , self .cost_for_crash , 0.0 , additional_info
358+ })
359+ return StatusType .CRASHED , self .worst_possible_result , 0.0 , additional_run_info
328360
329361 if obj .exit_status in (pynisher .TimeoutException , pynisher .MemorylimitException ):
330362 # Even if the pynisher thinks that a timeout or memout occured,
@@ -359,7 +391,7 @@ def run(
359391 elif obj .exit_status is pynisher .MemorylimitException :
360392 status = StatusType .MEMOUT
361393 additional_run_info = {
362- ' error' : ' Memout (used more than %d MB).' % self .memory_limit
394+ " error" : " Memout (used more than {} MB)." . format ( self .memory_limit )
363395 }
364396 else :
365397 raise ValueError (obj .exit_status )
0 commit comments