Skip to content

Commit 0a707dd

Browse files
committed
type hint fixes for adaptive/learner/sequence_learner.py
1 parent de4a8c7 commit 0a707dd

File tree

1 file changed

+8
-27
lines changed

1 file changed

+8
-27
lines changed

adaptive/learner/sequence_learner.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from copy import copy
2-
from functools import partial
3-
from typing import Any, List, Tuple, Union
2+
from typing import Any, Callable, List, Sequence, Tuple, Union
43

54
import numpy as np
65
from sortedcontainers import SortedDict, SortedSet
@@ -18,22 +17,19 @@ class _IgnoreFirstArgument:
1817
pickable.
1918
"""
2019

21-
def __init__(self, function: partial) -> None:
20+
def __init__(self, function: Callable) -> None:
2221
self.function = function
2322

2423
def __call__(
25-
self,
26-
index_point: Union[Tuple[int, int], Tuple[int, float], Tuple[int, np.ndarray]],
27-
*args,
28-
**kwargs
24+
self, index_point: Tuple[int, Union[float, np.ndarray]], *args, **kwargs
2925
) -> float:
3026
index, point = index_point
3127
return self.function(point, *args, **kwargs)
3228

33-
def __getstate__(self) -> partial:
29+
def __getstate__(self) -> Callable:
3430
return self.function
3531

36-
def __setstate__(self, function: partial) -> None:
32+
def __setstate__(self, function: Callable) -> None:
3733
self.__init__(function)
3834

3935

@@ -64,7 +60,7 @@ class SequenceLearner(BaseLearner):
6460
the added benefit of having results in the local kernel already.
6561
"""
6662

67-
def __init__(self, function: partial, sequence: Union[range, np.ndarray]) -> None:
63+
def __init__(self, function: Callable, sequence: Sequence) -> None:
6864
self._original_function = function
6965
self.function = _IgnoreFirstArgument(function)
7066
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
@@ -73,13 +69,7 @@ def __init__(self, function: partial, sequence: Union[range, np.ndarray]) -> Non
7369
self.data = SortedDict()
7470
self.pending_points = set()
7571

76-
def ask(
77-
self, n: int, tell_pending: bool = True
78-
) -> Union[
79-
Tuple[List[Tuple[int, float]], List[float]],
80-
Tuple[List[Tuple[int, int]], List[float]],
81-
Tuple[List[Tuple[int, np.ndarray]], List[float]],
82-
]:
72+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[Any, List[float]]:
8373
indices = []
8474
points = []
8575
loss_improvements = []
@@ -119,16 +109,7 @@ def remove_unfinished(self):
119109
self._to_do_indices.add(i)
120110
self.pending_points = set()
121111

122-
def tell(
123-
self,
124-
point: Union[
125-
Tuple[int, int],
126-
Tuple[int, float],
127-
Tuple[int, np.ndarray],
128-
Tuple[int, None],
129-
],
130-
value: float,
131-
) -> None:
112+
def tell(self, point: Tuple[int, Any], value: Any,) -> None:
132113
index, point = point
133114
self.data[index] = value
134115
self.pending_points.discard(index)

0 commit comments

Comments
 (0)