4
4
from contextlib import suppress
5
5
from functools import partial
6
6
from operator import itemgetter
7
- from typing import Any , Callable , Dict , List , Set , Tuple , Union
7
+ from typing import (
8
+ Any ,
9
+ Callable ,
10
+ Dict ,
11
+ List ,
12
+ Literal ,
13
+ Optional ,
14
+ Sequence ,
15
+ Set ,
16
+ Tuple ,
17
+ Union ,
18
+ )
8
19
9
20
import numpy as np
10
21
@@ -18,6 +29,14 @@ def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
18
29
return child_functions [index ](x )
19
30
20
31
32
+ STRATEGY_TYPE = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
33
+
34
+ CDIMS_TYPE = Union [
35
+ Sequence [Dict [str , Any ]],
36
+ Tuple [Sequence [str ], Sequence [Tuple [Any , ...]]],
37
+ ]
38
+
39
+
21
40
class BalancingLearner (BaseLearner ):
22
41
r"""Choose the optimal points from a set of learners.
23
42
@@ -70,7 +89,11 @@ class BalancingLearner(BaseLearner):
70
89
"""
71
90
72
91
def __init__ (
73
- self , learners : List [BaseLearner ], * , cdims = None , strategy = "loss_improvements"
92
+ self ,
93
+ learners : List [BaseLearner ],
94
+ * ,
95
+ cdims : Optional [CDIMS_TYPE ] = None ,
96
+ strategy : STRATEGY_TYPE = "loss_improvements"
74
97
) -> None :
75
98
self .learners = learners
76
99
@@ -89,7 +112,7 @@ def __init__(
89
112
"A BalacingLearner can handle only one type" " of learners."
90
113
)
91
114
92
- self .strategy = strategy
115
+ self .strategy : STRATEGY_TYPE = strategy
93
116
94
117
@property
95
118
def data (self ) -> Dict [Tuple [int , Any ], Any ]:
@@ -110,7 +133,7 @@ def npoints(self) -> int:
110
133
return sum (l .npoints for l in self .learners )
111
134
112
135
@property
113
- def strategy (self ):
136
+ def strategy (self ) -> STRATEGY_TYPE :
114
137
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
115
138
'cycle'. The points that the `BalancingLearner` choses can be either
116
139
based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -121,7 +144,7 @@ def strategy(self):
121
144
return self ._strategy
122
145
123
146
@strategy .setter
124
- def strategy (self , strategy ) :
147
+ def strategy (self , strategy : STRATEGY_TYPE ) -> None :
125
148
self ._strategy = strategy
126
149
if strategy == "loss_improvements" :
127
150
self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
@@ -255,11 +278,16 @@ def _losses(self, real: bool = True) -> List[float]:
255
278
return losses
256
279
257
280
@cache_latest
258
- def loss (self , real : bool = True ) -> Union [ float ] :
281
+ def loss (self , real : bool = True ) -> float :
259
282
losses = self ._losses (real )
260
283
return max (losses )
261
284
262
- def plot (self , cdims = None , plotter = None , dynamic = True ):
285
+ def plot (
286
+ self ,
287
+ cdims : Optional [CDIMS_TYPE ] = None ,
288
+ plotter : Optional [Callable [[BaseLearner ], Any ]] = None ,
289
+ dynamic : bool = True ,
290
+ ):
263
291
"""Returns a DynamicMap with sliders.
264
292
265
293
Parameters
@@ -332,14 +360,18 @@ def plot_function(*args):
332
360
vals = {d .name : d .values for d in dm .dimensions () if d .values }
333
361
return hv .HoloMap (dm .select (** vals ))
334
362
335
- def remove_unfinished (self ):
363
+ def remove_unfinished (self ) -> None :
336
364
"""Remove uncomputed data from the learners."""
337
365
for learner in self .learners :
338
366
learner .remove_unfinished ()
339
367
340
368
@classmethod
341
369
def from_product (
342
- cls , f , learner_type , learner_kwargs , combos
370
+ cls ,
371
+ f ,
372
+ learner_type : BaseLearner ,
373
+ learner_kwargs : Dict [str , Any ],
374
+ combos : Dict [str , Iterable [Any ]],
343
375
) -> "BalancingLearner" :
344
376
"""Create a `BalancingLearner` with learners of all combinations of
345
377
named variables’ values. The `cdims` will be set correctly, so calling
@@ -387,7 +419,11 @@ def from_product(
387
419
learners .append (learner )
388
420
return cls (learners , cdims = arguments )
389
421
390
- def save (self , fname : Callable , compress : bool = True ) -> None :
422
+ def save (
423
+ self ,
424
+ fname : Union [Callable [[BaseLearner ], str ], Sequence [str ]],
425
+ compress : bool = True ,
426
+ ) -> None :
391
427
"""Save the data of the child learners into pickle files
392
428
in a directory.
393
429
@@ -425,7 +461,11 @@ def save(self, fname: Callable, compress: bool = True) -> None:
425
461
for l in self .learners :
426
462
l .save (fname (l ), compress = compress )
427
463
428
- def load (self , fname : Callable , compress : bool = True ) -> None :
464
+ def load (
465
+ self ,
466
+ fname : Union [Callable [[BaseLearner ], str ], Sequence [str ]],
467
+ compress : bool = True ,
468
+ ) -> None :
429
469
"""Load the data of the child learners from pickle files
430
470
in a directory.
431
471
@@ -449,20 +489,20 @@ def load(self, fname: Callable, compress: bool = True) -> None:
449
489
for l in self .learners :
450
490
l .load (fname (l ), compress = compress )
451
491
452
- def _get_data (self ):
492
+ def _get_data (self ) -> List [ Any ] :
453
493
return [l ._get_data () for l in self .learners ]
454
494
455
- def _set_data (self , data ):
495
+ def _set_data (self , data : List [ Any ] ):
456
496
for l , _data in zip (self .learners , data ):
457
497
l ._set_data (_data )
458
498
459
- def __getstate__ (self ):
499
+ def __getstate__ (self ) -> Tuple [ List [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] :
460
500
return (
461
501
self .learners ,
462
502
self ._cdims_default ,
463
503
self .strategy ,
464
504
)
465
505
466
- def __setstate__ (self , state ):
506
+ def __setstate__ (self , state : Tuple [ List [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] ):
467
507
learners , cdims , strategy = state
468
508
self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments