Skip to content

Commit 341d313

Browse files
committed
add type annotations for adaptive/learner/skopt_learner.py
1 parent e394995 commit 341d313

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

adaptive/learner/skopt_learner.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import collections
2+
from typing import Callable, List, Tuple, Union
23

34
import numpy as np
5+
from numpy import float64
46
from skopt import Optimizer
57

68
from adaptive.learner.base_learner import BaseLearner
@@ -23,13 +25,15 @@ class SKOptLearner(Optimizer, BaseLearner):
2325
Arguments to pass to ``skopt.Optimizer``.
2426
"""
2527

26-
def __init__(self, function, **kwargs):
28+
def __init__(self, function: Callable, **kwargs) -> None:
2729
self.function = function
2830
self.pending_points = set()
2931
self.data = collections.OrderedDict()
3032
super().__init__(**kwargs)
3133

32-
def tell(self, x, y, fit=True):
34+
def tell(
35+
self, x: Union[float64, List[float64]], y: float64, fit: bool = True
36+
) -> None:
3337
if isinstance(x, collections.abc.Iterable):
3438
self.pending_points.discard(tuple(x))
3539
self.data[tuple(x)] = y
@@ -48,7 +52,7 @@ def remove_unfinished(self):
4852
pass
4953

5054
@cache_latest
51-
def loss(self, real=True):
55+
def loss(self, real: bool = True) -> Union[float64, float]:
5256
if not self.models:
5357
return np.inf
5458
else:
@@ -58,7 +62,14 @@ def loss(self, real=True):
5862
# estimator of loss, but it is the cheapest.
5963
return 1 - model.score(self.Xi, self.yi)
6064

61-
def ask(self, n, tell_pending=True):
65+
def ask(
66+
self, n: int, tell_pending: bool = True
67+
) -> Union[
68+
Tuple[List[float64], List[float64]],
69+
Tuple[List[List[float64]], List[float64]],
70+
Tuple[List[List[float64]], List[float]],
71+
Tuple[List[float64], List[float]],
72+
]:
6273
if not tell_pending:
6374
raise NotImplementedError(
6475
"Asking points is an irreversible "
@@ -72,7 +83,7 @@ def ask(self, n, tell_pending=True):
7283
return [p[0] for p in points], [self.loss() / n] * n
7384

7485
@property
75-
def npoints(self):
86+
def npoints(self) -> int:
7687
"""Number of evaluated points."""
7788
return len(self.Xi)
7889

0 commit comments

Comments
 (0)