
Commit cda9876

Enable Mypy in evaluation (except Train Evaluator) (#1077)
* Almost all files for evaluation
* Feedback from PR
* Feedback from comments
* Solving rebase artifacts
* Revert bytes
1 parent 13f9c1f commit cda9876


7 files changed: +556 additions, -328 deletions


.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,10 @@ repos:
         args: [--show-error-codes]
         name: mypy auto-sklearn-util
         files: autosklearn/util
+      - id: mypy
+        args: [--show-error-codes]
+        name: mypy auto-sklearn-evaluation
+        files: autosklearn/evaluation
   - repo: https://gitlab.com/pycqa/flake8
     rev: 3.8.3
     hooks:
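
The added hook mirrors the existing mypy auto-sklearn-util entry, so files under autosklearn/evaluation are now type-checked on every commit. As a rough sketch (assuming mypy is installed and the command runs from the repository root), the same check can be reproduced outside pre-commit through mypy's Python API:

# Sketch only: reproduce the check performed by the new pre-commit hook above.
# Assumes mypy is installed; the flag and path are taken from the hook config.
from mypy import api

report, errors, exit_status = api.run(["--show-error-codes", "autosklearn/evaluation"])
print(report)                     # mypy's findings, with error codes
print("exit code:", exit_status)  # 0 means the package type-checks cleanly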

autosklearn/evaluation/__init__.py

Lines changed: 50 additions & 18 deletions
@@ -7,12 +7,13 @@
 from queue import Empty
 import time
 import traceback
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
 
 from ConfigSpace import Configuration
 import numpy as np
 import pynisher
 from smac.runhistory.runhistory import RunInfo, RunValue
+from smac.stats.stats import Stats
 from smac.tae import StatusType, TAEAbortException
 from smac.tae.execute_func import AbstractTAFunc
 
@@ -23,11 +24,17 @@
 import autosklearn.evaluation.train_evaluator
 import autosklearn.evaluation.test_evaluator
 import autosklearn.evaluation.util
-from autosklearn.util.logging_ import get_named_client_logger
+from autosklearn.evaluation.train_evaluator import TYPE_ADDITIONAL_INFO
+from autosklearn.util.backend import Backend
+from autosklearn.util.logging_ import PickableLoggerAdapter, get_named_client_logger
 from autosklearn.util.parallel import preload_modules
 
 
-def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
+def fit_predict_try_except_decorator(
+        ta: Callable,
+        queue: multiprocessing.Queue,
+        cost_for_crash: float,
+        **kwargs: Any) -> None:
 
     try:
         return ta(queue=queue, **kwargs)
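
The decorator now spells out its parameter types (the target callable, the result queue, and the crash cost) and returns None. A minimal, self-contained sketch of the same pattern, with generic names rather than auto-sklearn's own evaluators:

# Illustrative sketch: a typed try/except wrapper that reports results or
# failures through a multiprocessing queue, mirroring the signature above.
import multiprocessing
from typing import Any, Callable


def run_with_queue(ta: Callable, queue: multiprocessing.Queue,
                   cost_for_crash: float, **kwargs: Any) -> None:
    try:
        ta(queue=queue, **kwargs)
    except Exception as e:
        # On failure, report the crash cost and the error, then stop feeding the queue.
        queue.put({'loss': cost_for_crash, 'status': 'CRASHED', 'error': repr(e)})
        queue.close()


if __name__ == '__main__':
    q: multiprocessing.Queue = multiprocessing.get_context('spawn').Queue()
    run_with_queue(lambda queue: queue.put({'loss': 0.1, 'status': 'OK'}), q, 1.0)
    print(q.get())  # {'loss': 0.1, 'status': 'OK'}
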
@@ -66,7 +73,7 @@ def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
         queue.close()
 
 
-def get_cost_of_crash(metric):
+def get_cost_of_crash(metric: Scorer) -> float:
 
     # The metric must always be defined to extract optimum/worst
     if not isinstance(metric, Scorer):
@@ -85,8 +92,11 @@ def get_cost_of_crash(metric):
     return worst_possible_result
 
 
-def _encode_exit_status(exit_status):
+def _encode_exit_status(exit_status: Union[str, int, Type[BaseException]]
+                        ) -> Union[str, int]:
     try:
+        # If it can be dumped, then it is int
+        exit_status = cast(int, exit_status)
         json.dumps(exit_status)
         return exit_status
     except (TypeError, OverflowError):
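
The cast(int, ...) above has no runtime effect; it only tells mypy to treat the value as an int so json.dumps and the return type line up, while the try/except still does the real work of catching non-serialisable exit statuses. A standalone sketch of the idiom (function name and sample values are illustrative):

# Sketch: typing.cast is a no-op at runtime, so json.dumps still decides whether
# the value is serialisable (an int) or must be stringified (an exception type).
import json
from typing import Type, Union, cast


def encode_exit_status(exit_status: Union[str, int, Type[BaseException]]) -> Union[str, int]:
    try:
        exit_status = cast(int, exit_status)  # informs the type checker only
        json.dumps(exit_status)               # raises TypeError for class objects
        return exit_status
    except (TypeError, OverflowError):
        return str(exit_status)


print(encode_exit_status(0))            # 0
print(encode_exit_status(MemoryError))  # <class 'MemoryError'>
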
@@ -97,13 +107,31 @@ def _encode_exit_status(exit_status):
 # easier debugging of potential crashes
 class ExecuteTaFuncWithQueue(AbstractTAFunc):
 
-    def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
-                 cost_for_crash, abort_on_first_run_crash, port, pynisher_context,
-                 initial_num_run=1, stats=None,
-                 run_obj='quality', par_factor=1, scoring_functions=None,
-                 output_y_hat_optimization=True, include=None, exclude=None,
-                 memory_limit=None, disable_file_output=False, init_params=None,
-                 budget_type=None, ta=False, **resampling_strategy_args):
+    def __init__(
+        self,
+        backend: Backend,
+        autosklearn_seed: int,
+        resampling_strategy: Union[str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit],
+        metric: Scorer,
+        cost_for_crash: float,
+        abort_on_first_run_crash: bool,
+        port: int,
+        pynisher_context: str,
+        initial_num_run: int = 1,
+        stats: Optional[Stats] = None,
+        run_obj: str = 'quality',
+        par_factor: int = 1,
+        scoring_functions: Optional[List[Scorer]] = None,
+        output_y_hat_optimization: bool = True,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
+        memory_limit: Optional[int] = None,
+        disable_file_output: bool = False,
+        init_params: Optional[Dict[str, Any]] = None,
+        budget_type: Optional[str] = None,
+        ta: Optional[Callable] = None,
+        **resampling_strategy_args: Any,
+    ):
 
         if resampling_strategy == 'holdout':
             eval_function = autosklearn.evaluation.train_evaluator.eval_holdout
@@ -180,7 +208,7 @@ def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
         self.port = port
         self.pynisher_context = pynisher_context
         if self.port is None:
-            self.logger = logging.getLogger("TAE")
+            self.logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("TAE")
         else:
             self.logger = get_named_client_logger(
                 name="TAE",
@@ -261,6 +289,10 @@ def run(
         instance_specific: Optional[str] = None,
     ) -> Tuple[StatusType, float, float, Dict[str, Union[int, float, str, Dict, List, Tuple]]]:
 
+        # Additional information of each of the tae executions
+        # Defined upfront for mypy
+        additional_run_info: TYPE_ADDITIONAL_INFO = {}
+
         context = multiprocessing.get_context(self.pynisher_context)
         preload_modules(context)
         queue = context.Queue()
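
Declaring additional_run_info as an explicitly typed, empty dict at the top of run() lets the later branches fill it via update() without mypy re-inferring a conflicting type for each branch. A small sketch of the pattern, with a hypothetical alias standing in for TYPE_ADDITIONAL_INFO:

# Sketch: pre-declare the dict with its full type so every branch below
# type-checks against the same alias (RunInfoDict is a stand-in, not auto-sklearn's).
from typing import Dict, Union

RunInfoDict = Dict[str, Union[int, float, str]]


def run(crashed: bool) -> RunInfoDict:
    additional_run_info: RunInfoDict = {}  # defined upfront for the type checker
    if crashed:
        additional_run_info.update({'error': 'worker crashed', 'cost': 1.0})
    else:
        additional_run_info.update({'loss': 0.05, 'duration': 12})
    return additional_run_info


print(run(crashed=False))  # {'loss': 0.05, 'duration': 12}
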
@@ -272,7 +304,7 @@ def run(
             init_params.update(self.init_params)
 
         if self.port is None:
-            logger = logging.getLogger("pynisher")
+            logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("pynisher")
         else:
             logger = get_named_client_logger(
                 name="pynisher",
@@ -320,11 +352,11 @@ def run(
         except Exception as e:
             exception_traceback = traceback.format_exc()
             error_message = repr(e)
-            additional_info = {
+            additional_run_info.update({
                 'traceback': exception_traceback,
                 'error': error_message
-            }
-            return StatusType.CRASHED, self.cost_for_crash, 0.0, additional_info
+            })
+            return StatusType.CRASHED, self.worst_possible_result, 0.0, additional_run_info
 
         if obj.exit_status in (pynisher.TimeoutException, pynisher.MemorylimitException):
             # Even if the pynisher thinks that a timeout or memout occured,
@@ -359,7 +391,7 @@ def run(
             elif obj.exit_status is pynisher.MemorylimitException:
                 status = StatusType.MEMOUT
                 additional_run_info = {
-                    'error': 'Memout (used more than %d MB).' % self.memory_limit
+                    "error": "Memout (used more than {} MB).".format(self.memory_limit)
                 }
             else:
                 raise ValueError(obj.exit_status)
