4
4
from contextlib import suppress
5
5
from functools import partial
6
6
from operator import itemgetter
7
+ from typing import Any , Callable , Dict , List , Set , Tuple , Union
7
8
8
9
import numpy as np
10
+ from numpy import float64 , int64
9
11
12
+ from adaptive .learner .average_learner import AverageLearner
10
13
from adaptive .learner .base_learner import BaseLearner
14
+ from adaptive .learner .learner1D import Learner1D
15
+ from adaptive .learner .learner2D import Learner2D
16
+ from adaptive .learner .learnerND import LearnerND
17
+ from adaptive .learner .sequence_learner import SequenceLearner , _IgnoreFirstArgument
11
18
from adaptive .notebook_integration import ensure_holoviews
12
19
from adaptive .utils import cache_latest , named_product , restore
13
20
14
21
15
- def dispatch (child_functions , arg ):
22
+ def dispatch (
23
+ child_functions : Union [List [Callable ], List [partial ], List [_IgnoreFirstArgument ]],
24
+ arg : Any ,
25
+ ) -> Union [int , float64 , float ]:
16
26
index , x = arg
17
27
return child_functions [index ](x )
18
28
@@ -68,7 +78,19 @@ class BalancingLearner(BaseLearner):
68
78
behave in an undefined way. Change the `strategy` in that case.
69
79
"""
70
80
71
- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
81
+ def __init__ (
82
+ self ,
83
+ learners : Union [
84
+ List [SequenceLearner ],
85
+ List [AverageLearner ],
86
+ List [Learner2D ],
87
+ List [Learner1D ],
88
+ List [LearnerND ],
89
+ ],
90
+ * ,
91
+ cdims = None ,
92
+ strategy = "loss_improvements"
93
+ ) -> None :
72
94
self .learners = learners
73
95
74
96
# Naively we would make 'function' a method, but this causes problems
@@ -89,21 +111,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
89
111
self .strategy = strategy
90
112
91
113
@property
92
- def data (self ):
114
+ def data (self ) -> Dict [ Tuple [ int , int ], int ] :
93
115
data = {}
94
116
for i , l in enumerate (self .learners ):
95
117
data .update ({(i , p ): v for p , v in l .data .items ()})
96
118
return data
97
119
98
120
@property
99
- def pending_points (self ):
121
+ def pending_points (self ) -> Set [ Tuple [ int , int ]] :
100
122
pending_points = set ()
101
123
for i , l in enumerate (self .learners ):
102
124
pending_points .update ({(i , p ) for p in l .pending_points })
103
125
return pending_points
104
126
105
127
@property
106
- def npoints (self ):
128
+ def npoints (self ) -> int :
107
129
return sum (l .npoints for l in self .learners )
108
130
109
131
@property
@@ -135,7 +157,7 @@ def strategy(self, strategy):
135
157
' strategy="npoints", or strategy="cycle" is implemented.'
136
158
)
137
159
138
- def _ask_and_tell_based_on_loss_improvements (self , n ) :
160
+ def _ask_and_tell_based_on_loss_improvements (self , n : int ) -> Any :
139
161
selected = [] # tuples ((learner_index, point), loss_improvement)
140
162
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
141
163
for _ in range (n ):
@@ -158,7 +180,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158
180
points , loss_improvements = map (list , zip (* selected ))
159
181
return points , loss_improvements
160
182
161
- def _ask_and_tell_based_on_loss (self , n ):
183
+ def _ask_and_tell_based_on_loss (
184
+ self , n : int
185
+ ) -> Union [
186
+ Tuple [List [Tuple [int , float ]], List [float64 ]],
187
+ Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
188
+ Tuple [List [Tuple [int , int ]], List [float ]],
189
+ ]:
162
190
selected = [] # tuples ((learner_index, point), loss_improvement)
163
191
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
164
192
for _ in range (n ):
@@ -179,7 +207,13 @@ def _ask_and_tell_based_on_loss(self, n):
179
207
points , loss_improvements = map (list , zip (* selected ))
180
208
return points , loss_improvements
181
209
182
- def _ask_and_tell_based_on_npoints (self , n ):
210
+ def _ask_and_tell_based_on_npoints (
211
+ self , n : int
212
+ ) -> Union [
213
+ Tuple [List [Union [Tuple [int64 , int ], Tuple [int64 , float ]]], List [float ]],
214
+ Tuple [List [Tuple [int64 , float ]], List [float64 ]],
215
+ Tuple [List [Tuple [int64 , int ]], List [float ]],
216
+ ]:
183
217
selected = [] # tuples ((learner_index, point), loss_improvement)
184
218
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
185
219
for _ in range (n ):
@@ -195,7 +229,13 @@ def _ask_and_tell_based_on_npoints(self, n):
195
229
points , loss_improvements = map (list , zip (* selected ))
196
230
return points , loss_improvements
197
231
198
- def _ask_and_tell_based_on_cycle (self , n ):
232
+ def _ask_and_tell_based_on_cycle (
233
+ self , n : int
234
+ ) -> Union [
235
+ Tuple [List [Tuple [int , float ]], List [float64 ]],
236
+ Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
237
+ Tuple [List [Tuple [int , int ]], List [float ]],
238
+ ]:
199
239
points , loss_improvements = [], []
200
240
for _ in range (n ):
201
241
index = next (self ._cycle )
@@ -206,7 +246,7 @@ def _ask_and_tell_based_on_cycle(self, n):
206
246
207
247
return points , loss_improvements
208
248
209
- def ask (self , n , tell_pending = True ):
249
+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
210
250
"""Chose points for learners."""
211
251
if n == 0 :
212
252
return [], []
@@ -217,20 +257,24 @@ def ask(self, n, tell_pending=True):
217
257
else :
218
258
return self ._ask_and_tell (n )
219
259
220
- def tell (self , x , y ):
260
+ def tell (
261
+ self , x : Any , y : Union [int , float64 , float , Tuple [int , int ], Tuple [int64 , int ]]
262
+ ) -> None :
221
263
index , x = x
222
264
self ._ask_cache .pop (index , None )
223
265
self ._loss .pop (index , None )
224
266
self ._pending_loss .pop (index , None )
225
267
self .learners [index ].tell (x , y )
226
268
227
- def tell_pending (self , x ) :
269
+ def tell_pending (self , x : Any ) -> None :
228
270
index , x = x
229
271
self ._ask_cache .pop (index , None )
230
272
self ._loss .pop (index , None )
231
273
self .learners [index ].tell_pending (x )
232
274
233
- def _losses (self , real = True ):
275
+ def _losses (
276
+ self , real : bool = True
277
+ ) -> Union [List [float ], List [float64 ], List [Union [float , float64 ]]]:
234
278
losses = []
235
279
loss_dict = self ._loss if real else self ._pending_loss
236
280
@@ -242,7 +286,7 @@ def _losses(self, real=True):
242
286
return losses
243
287
244
288
@cache_latest
245
- def loss (self , real = True ):
289
+ def loss (self , real : bool = True ) -> Union [ float64 , float ] :
246
290
losses = self ._losses (real )
247
291
return max (losses )
248
292
@@ -372,7 +416,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372
416
learners .append (learner )
373
417
return cls (learners , cdims = arguments )
374
418
375
- def save (self , fname , compress = True ):
419
+ def save (self , fname : Callable , compress : bool = True ) -> None :
376
420
"""Save the data of the child learners into pickle files
377
421
in a directory.
378
422
@@ -410,7 +454,7 @@ def save(self, fname, compress=True):
410
454
for l in self .learners :
411
455
l .save (fname (l ), compress = compress )
412
456
413
- def load (self , fname , compress = True ):
457
+ def load (self , fname : Callable , compress : bool = True ) -> None :
414
458
"""Load the data of the child learners from pickle files
415
459
in a directory.
416
460
0 commit comments