Skip to content

Commit d8d2fba

Browse files
committed
add type annotations for adaptive/learner/sequence_learner.py
1 parent fa0a0ef commit d8d2fba

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

adaptive/learner/sequence_learner.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from copy import copy
2+
from functools import partial
3+
from typing import Any, List, Tuple, Union
24

5+
from numpy import float64, ndarray
36
from sortedcontainers import SortedDict, SortedSet
47

58
from adaptive.learner.base_learner import BaseLearner
@@ -15,17 +18,22 @@ class _IgnoreFirstArgument:
1518
pickable.
1619
"""
1720

18-
def __init__(self, function):
21+
def __init__(self, function: partial) -> None:
1922
self.function = function
2023

21-
def __call__(self, index_point, *args, **kwargs):
24+
def __call__(
25+
self,
26+
index_point: Union[Tuple[int, int], Tuple[int, float64], Tuple[int, ndarray]],
27+
*args,
28+
**kwargs
29+
) -> Union[float64, float]:
2230
index, point = index_point
2331
return self.function(point, *args, **kwargs)
2432

25-
def __getstate__(self):
33+
def __getstate__(self) -> partial:
2634
return self.function
2735

28-
def __setstate__(self, function):
36+
def __setstate__(self, function: partial) -> None:
2937
self.__init__(function)
3038

3139

@@ -56,7 +64,7 @@ class SequenceLearner(BaseLearner):
5664
the added benefit of having results in the local kernel already.
5765
"""
5866

59-
def __init__(self, function, sequence):
67+
def __init__(self, function: partial, sequence: Union[range, ndarray]) -> None:
6068
self._original_function = function
6169
self.function = _IgnoreFirstArgument(function)
6270
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
@@ -65,7 +73,13 @@ def __init__(self, function, sequence):
6573
self.data = SortedDict()
6674
self.pending_points = set()
6775

68-
def ask(self, n, tell_pending=True):
76+
def ask(
77+
self, n: int, tell_pending: bool = True
78+
) -> Union[
79+
Tuple[List[Tuple[int, float64]], List[float]],
80+
Tuple[List[Tuple[int, int]], List[float]],
81+
Tuple[List[Tuple[int, ndarray]], List[float]],
82+
]:
6983
indices = []
7084
points = []
7185
loss_improvements = []
@@ -83,17 +97,17 @@ def ask(self, n, tell_pending=True):
8397

8498
return points, loss_improvements
8599

86-
def _get_data(self):
100+
def _get_data(self) -> SortedDict:
87101
return self.data
88102

89-
def _set_data(self, data):
103+
def _set_data(self, data: SortedDict) -> None:
90104
if data:
91105
indices, values = zip(*data.items())
92106
# the points aren't used by tell, so we can safely pass None
93107
points = [(i, None) for i in indices]
94108
self.tell_many(points, values)
95109

96-
def loss(self, real=True):
110+
def loss(self, real: bool = True) -> float:
97111
if not (self._to_do_indices or self.pending_points):
98112
return 0
99113
else:
@@ -105,13 +119,19 @@ def remove_unfinished(self):
105119
self._to_do_indices.add(i)
106120
self.pending_points = set()
107121

108-
def tell(self, point, value):
122+
def tell(
123+
self,
124+
point: Union[
125+
Tuple[int, int], Tuple[int, float64], Tuple[int, ndarray], Tuple[int, None]
126+
],
127+
value: Union[float64, float],
128+
) -> None:
109129
index, point = point
110130
self.data[index] = value
111131
self.pending_points.discard(index)
112132
self._to_do_indices.discard(index)
113133

114-
def tell_pending(self, point):
134+
def tell_pending(self, point: Any) -> None:
115135
index, point = point
116136
self.pending_points.add(index)
117137
self._to_do_indices.discard(index)
@@ -126,5 +146,5 @@ def result(self):
126146
return list(self.data.values())
127147

128148
@property
129-
def npoints(self):
149+
def npoints(self) -> int:
130150
return len(self.data)

0 commit comments

Comments
 (0)