1
1
from copy import copy
2
- from functools import partial
3
- from typing import Any , List , Tuple , Union
2
+ from typing import Any , Callable , List , Sequence , Tuple , Union
4
3
5
4
import numpy as np
6
5
from sortedcontainers import SortedDict , SortedSet
@@ -18,22 +17,19 @@ class _IgnoreFirstArgument:
18
17
pickable.
19
18
"""
20
19
21
- def __init__ (self , function : partial ) -> None :
20
+ def __init__ (self , function : Callable ) -> None :
22
21
self .function = function
23
22
24
23
def __call__ (
25
- self ,
26
- index_point : Union [Tuple [int , int ], Tuple [int , float ], Tuple [int , np .ndarray ]],
27
- * args ,
28
- ** kwargs
24
+ self , index_point : Tuple [int , Union [float , np .ndarray ]], * args , ** kwargs
29
25
) -> float :
30
26
index , point = index_point
31
27
return self .function (point , * args , ** kwargs )
32
28
33
- def __getstate__ (self ) -> partial :
29
+ def __getstate__ (self ) -> Callable :
34
30
return self .function
35
31
36
- def __setstate__ (self , function : partial ) -> None :
32
+ def __setstate__ (self , function : Callable ) -> None :
37
33
self .__init__ (function )
38
34
39
35
@@ -64,7 +60,7 @@ class SequenceLearner(BaseLearner):
64
60
the added benefit of having results in the local kernel already.
65
61
"""
66
62
67
- def __init__ (self , function : partial , sequence : Union [ range , np . ndarray ] ) -> None :
63
+ def __init__ (self , function : Callable , sequence : Sequence ) -> None :
68
64
self ._original_function = function
69
65
self .function = _IgnoreFirstArgument (function )
70
66
self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
@@ -73,13 +69,7 @@ def __init__(self, function: partial, sequence: Union[range, np.ndarray]) -> Non
73
69
self .data = SortedDict ()
74
70
self .pending_points = set ()
75
71
76
- def ask (
77
- self , n : int , tell_pending : bool = True
78
- ) -> Union [
79
- Tuple [List [Tuple [int , float ]], List [float ]],
80
- Tuple [List [Tuple [int , int ]], List [float ]],
81
- Tuple [List [Tuple [int , np .ndarray ]], List [float ]],
82
- ]:
72
+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [Any , List [float ]]:
83
73
indices = []
84
74
points = []
85
75
loss_improvements = []
@@ -119,16 +109,7 @@ def remove_unfinished(self):
119
109
self ._to_do_indices .add (i )
120
110
self .pending_points = set ()
121
111
122
- def tell (
123
- self ,
124
- point : Union [
125
- Tuple [int , int ],
126
- Tuple [int , float ],
127
- Tuple [int , np .ndarray ],
128
- Tuple [int , None ],
129
- ],
130
- value : float ,
131
- ) -> None :
112
+ def tell (self , point : Tuple [int , Any ], value : Any ,) -> None :
132
113
index , point = point
133
114
self .data [index ] = value
134
115
self .pending_points .discard (index )
0 commit comments