1
1
from copy import copy
2
+ from functools import partial
3
+ from typing import Any , List , Tuple , Union
2
4
5
+ from numpy import float64 , ndarray
3
6
from sortedcontainers import SortedDict , SortedSet
4
7
5
8
from adaptive .learner .base_learner import BaseLearner
@@ -15,17 +18,22 @@ class _IgnoreFirstArgument:
15
18
pickable.
16
19
"""
17
20
18
- def __init__ (self , function ) :
21
+ def __init__ (self , function : partial ) -> None :
19
22
self .function = function
20
23
21
- def __call__ (self , index_point , * args , ** kwargs ):
24
+ def __call__ (
25
+ self ,
26
+ index_point : Union [Tuple [int , int ], Tuple [int , float64 ], Tuple [int , ndarray ]],
27
+ * args ,
28
+ ** kwargs
29
+ ) -> Union [float64 , float ]:
22
30
index , point = index_point
23
31
return self .function (point , * args , ** kwargs )
24
32
25
- def __getstate__ (self ):
33
+ def __getstate__ (self ) -> partial :
26
34
return self .function
27
35
28
- def __setstate__ (self , function ) :
36
+ def __setstate__ (self , function : partial ) -> None :
29
37
self .__init__ (function )
30
38
31
39
@@ -56,7 +64,7 @@ class SequenceLearner(BaseLearner):
56
64
the added benefit of having results in the local kernel already.
57
65
"""
58
66
59
- def __init__ (self , function , sequence ) :
67
+ def __init__ (self , function : partial , sequence : Union [ range , ndarray ]) -> None :
60
68
self ._original_function = function
61
69
self .function = _IgnoreFirstArgument (function )
62
70
self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
@@ -65,7 +73,13 @@ def __init__(self, function, sequence):
65
73
self .data = SortedDict ()
66
74
self .pending_points = set ()
67
75
68
- def ask (self , n , tell_pending = True ):
76
+ def ask (
77
+ self , n : int , tell_pending : bool = True
78
+ ) -> Union [
79
+ Tuple [List [Tuple [int , float64 ]], List [float ]],
80
+ Tuple [List [Tuple [int , int ]], List [float ]],
81
+ Tuple [List [Tuple [int , ndarray ]], List [float ]],
82
+ ]:
69
83
indices = []
70
84
points = []
71
85
loss_improvements = []
@@ -83,17 +97,17 @@ def ask(self, n, tell_pending=True):
83
97
84
98
return points , loss_improvements
85
99
86
- def _get_data (self ):
100
+ def _get_data (self ) -> SortedDict :
87
101
return self .data
88
102
89
- def _set_data (self , data ) :
103
+ def _set_data (self , data : SortedDict ) -> None :
90
104
if data :
91
105
indices , values = zip (* data .items ())
92
106
# the points aren't used by tell, so we can safely pass None
93
107
points = [(i , None ) for i in indices ]
94
108
self .tell_many (points , values )
95
109
96
- def loss (self , real = True ):
110
+ def loss (self , real : bool = True ) -> float :
97
111
if not (self ._to_do_indices or self .pending_points ):
98
112
return 0
99
113
else :
@@ -105,13 +119,19 @@ def remove_unfinished(self):
105
119
self ._to_do_indices .add (i )
106
120
self .pending_points = set ()
107
121
108
- def tell (self , point , value ):
122
+ def tell (
123
+ self ,
124
+ point : Union [
125
+ Tuple [int , int ], Tuple [int , float64 ], Tuple [int , ndarray ], Tuple [int , None ]
126
+ ],
127
+ value : Union [float64 , float ],
128
+ ) -> None :
109
129
index , point = point
110
130
self .data [index ] = value
111
131
self .pending_points .discard (index )
112
132
self ._to_do_indices .discard (index )
113
133
114
- def tell_pending (self , point ) :
134
+ def tell_pending (self , point : Any ) -> None :
115
135
index , point = point
116
136
self .pending_points .add (index )
117
137
self ._to_do_indices .discard (index )
@@ -126,5 +146,5 @@ def result(self):
126
146
return list (self .data .values ())
127
147
128
148
@property
129
- def npoints (self ):
149
+ def npoints (self ) -> int :
130
150
return len (self .data )
0 commit comments