13
13
from adaptive .utils import cache_latest , named_product , restore
14
14
15
15
16
- def dispatch (
17
- child_functions : List [Callable ], arg : Any ,
18
- ) -> Union [int , np .float64 , float ]:
16
+ def dispatch (child_functions : List [Callable ], arg : Any ,) -> Union [Any ]:
19
17
index , x = arg
20
18
return child_functions [index ](x )
21
19
@@ -94,14 +92,14 @@ def __init__(
94
92
self .strategy = strategy
95
93
96
94
@property
97
- def data (self ) -> Dict [Tuple [int , int ], int ]:
95
+ def data (self ) -> Dict [Tuple [int , Any ], Any ]:
98
96
data = {}
99
97
for i , l in enumerate (self .learners ):
100
98
data .update ({(i , p ): v for p , v in l .data .items ()})
101
99
return data
102
100
103
101
@property
104
- def pending_points (self ) -> Set [Tuple [int , int ]]:
102
+ def pending_points (self ) -> Set [Tuple [int , Any ]]:
105
103
pending_points = set ()
106
104
for i , l in enumerate (self .learners ):
107
105
pending_points .update ({(i , p ) for p in l .pending_points })
@@ -140,7 +138,9 @@ def strategy(self, strategy):
140
138
' strategy="npoints", or strategy="cycle" is implemented.'
141
139
)
142
140
143
- def _ask_and_tell_based_on_loss_improvements (self , n : int ) -> Any :
141
+ def _ask_and_tell_based_on_loss_improvements (
142
+ self , n : int
143
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
144
144
selected = [] # tuples ((learner_index, point), loss_improvement)
145
145
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
146
146
for _ in range (n ):
@@ -165,11 +165,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
165
165
166
166
def _ask_and_tell_based_on_loss (
167
167
self , n : int
168
- ) -> Union [
169
- Tuple [List [Tuple [int , float ]], List [np .float64 ]],
170
- Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
171
- Tuple [List [Tuple [int , int ]], List [float ]],
172
- ]:
168
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
173
169
selected = [] # tuples ((learner_index, point), loss_improvement)
174
170
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
175
171
for _ in range (n ):
@@ -192,11 +188,7 @@ def _ask_and_tell_based_on_loss(
192
188
193
189
def _ask_and_tell_based_on_npoints (
194
190
self , n : int
195
- ) -> Union [
196
- Tuple [List [Union [Tuple [np .int64 , int ], Tuple [np .int64 , float ]]], List [float ]],
197
- Tuple [List [Tuple [np .int64 , float ]], List [np .float64 ]],
198
- Tuple [List [Tuple [np .int64 , int ]], List [float ]],
199
- ]:
191
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
200
192
selected = [] # tuples ((learner_index, point), loss_improvement)
201
193
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
202
194
for _ in range (n ):
@@ -214,11 +206,7 @@ def _ask_and_tell_based_on_npoints(
214
206
215
207
def _ask_and_tell_based_on_cycle (
216
208
self , n : int
217
- ) -> Union [
218
- Tuple [List [Tuple [int , float ]], List [np .float64 ]],
219
- Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
220
- Tuple [List [Tuple [int , int ]], List [float ]],
221
- ]:
209
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
222
210
points , loss_improvements = [], []
223
211
for _ in range (n ):
224
212
index = next (self ._cycle )
@@ -229,7 +217,9 @@ def _ask_and_tell_based_on_cycle(
229
217
230
218
return points , loss_improvements
231
219
232
- def ask (self , n : int , tell_pending : bool = True ) -> Tuple [List [Any ], List [float ]]:
220
+ def ask (
221
+ self , n : int , tell_pending : bool = True
222
+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
233
223
"""Chose points for learners."""
234
224
if n == 0 :
235
225
return [], []
@@ -240,26 +230,20 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]
240
230
else :
241
231
return self ._ask_and_tell (n )
242
232
243
- def tell (
244
- self ,
245
- x : Any ,
246
- y : Union [int , np .float64 , float , Tuple [int , int ], Tuple [np .int64 , int ]],
247
- ) -> None :
233
+ def tell (self , x : Tuple [int , Any ], y : Any ,) -> None :
248
234
index , x = x
249
235
self ._ask_cache .pop (index , None )
250
236
self ._loss .pop (index , None )
251
237
self ._pending_loss .pop (index , None )
252
238
self .learners [index ].tell (x , y )
253
239
254
- def tell_pending (self , x : Any ) -> None :
240
+ def tell_pending (self , x : Tuple [ int , Any ] ) -> None :
255
241
index , x = x
256
242
self ._ask_cache .pop (index , None )
257
243
self ._loss .pop (index , None )
258
244
self .learners [index ].tell_pending (x )
259
245
260
- def _losses (
261
- self , real : bool = True
262
- ) -> Union [List [float ], List [np .float64 ], List [Union [float , np .float64 ]]]:
246
+ def _losses (self , real : bool = True ) -> List [float ]:
263
247
losses = []
264
248
loss_dict = self ._loss if real else self ._pending_loss
265
249
@@ -271,7 +255,7 @@ def _losses(
271
255
return losses
272
256
273
257
@cache_latest
274
- def loss (self , real : bool = True ) -> Union [np . float64 , float ]:
258
+ def loss (self , real : bool = True ) -> Union [float ]:
275
259
losses = self ._losses (real )
276
260
return max (losses )
277
261
0 commit comments