1
+ from __future__ import annotations
2
+
1
3
import itertools
2
4
import numbers
3
5
from collections import defaultdict
4
6
from collections .abc import Iterable
5
7
from contextlib import suppress
6
8
from functools import partial
7
9
from operator import itemgetter
8
- from typing import (
9
- Any ,
10
- Callable ,
11
- Dict ,
12
- List ,
13
- Literal ,
14
- Optional ,
15
- Sequence ,
16
- Set ,
17
- Tuple ,
18
- Union ,
19
- )
10
+ from typing import Any , Callable , Dict , Literal , Sequence , Tuple , Union
20
11
21
12
import numpy as np
22
13
25
16
from adaptive .utils import cache_latest , named_product , restore
26
17
27
18
28
- def dispatch (child_functions : List [Callable ], arg : Any ) -> Union [ Any ] :
19
+ def dispatch (child_functions : list [Callable ], arg : Any ) -> Any :
29
20
index , x = arg
30
21
return child_functions [index ](x )
31
22
@@ -91,9 +82,9 @@ class BalancingLearner(BaseLearner):
91
82
92
83
def __init__ (
93
84
self ,
94
- learners : List [BaseLearner ],
85
+ learners : list [BaseLearner ],
95
86
* ,
96
- cdims : Optional [ CDIMS_TYPE ] = None ,
87
+ cdims : CDIMS_TYPE | None = None ,
97
88
strategy : STRATEGY_TYPE = "loss_improvements" ,
98
89
) -> None :
99
90
self .learners = learners
@@ -116,14 +107,14 @@ def __init__(
116
107
self .strategy : STRATEGY_TYPE = strategy
117
108
118
109
@property
119
- def data (self ) -> Dict [ Tuple [int , Any ], Any ]:
110
+ def data (self ) -> dict [ tuple [int , Any ], Any ]:
120
111
data = {}
121
112
for i , l in enumerate (self .learners ):
122
113
data .update ({(i , p ): v for p , v in l .data .items ()})
123
114
return data
124
115
125
116
@property
126
- def pending_points (self ) -> Set [ Tuple [int , Any ]]:
117
+ def pending_points (self ) -> set [ tuple [int , Any ]]:
127
118
pending_points = set ()
128
119
for i , l in enumerate (self .learners ):
129
120
pending_points .update ({(i , p ) for p in l .pending_points })
@@ -173,7 +164,7 @@ def strategy(self, strategy: STRATEGY_TYPE) -> None:
173
164
174
165
def _ask_and_tell_based_on_loss_improvements (
175
166
self , n : int
176
- ) -> Tuple [ List [ Tuple [int , Any ]], List [float ]]:
167
+ ) -> tuple [ list [ tuple [int , Any ]], list [float ]]:
177
168
selected = [] # tuples ((learner_index, point), loss_improvement)
178
169
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
179
170
for _ in range (n ):
@@ -198,7 +189,7 @@ def _ask_and_tell_based_on_loss_improvements(
198
189
199
190
def _ask_and_tell_based_on_loss (
200
191
self , n : int
201
- ) -> Tuple [ List [ Tuple [int , Any ]], List [float ]]:
192
+ ) -> tuple [ list [ tuple [int , Any ]], list [float ]]:
202
193
selected = [] # tuples ((learner_index, point), loss_improvement)
203
194
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
204
195
for _ in range (n ):
@@ -221,7 +212,7 @@ def _ask_and_tell_based_on_loss(
221
212
222
213
def _ask_and_tell_based_on_npoints (
223
214
self , n : numbers .Integral
224
- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
215
+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
225
216
selected = [] # tuples ((learner_index, point), loss_improvement)
226
217
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
227
218
for _ in range (n ):
@@ -239,7 +230,7 @@ def _ask_and_tell_based_on_npoints(
239
230
240
231
def _ask_and_tell_based_on_cycle (
241
232
self , n : int
242
- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
233
+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
243
234
points , loss_improvements = [], []
244
235
for _ in range (n ):
245
236
index = next (self ._cycle )
@@ -252,7 +243,7 @@ def _ask_and_tell_based_on_cycle(
252
243
253
244
def ask (
254
245
self , n : int , tell_pending : bool = True
255
- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
246
+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
256
247
"""Chose points for learners."""
257
248
if n == 0 :
258
249
return [], []
@@ -263,20 +254,20 @@ def ask(
263
254
else :
264
255
return self ._ask_and_tell (n )
265
256
266
- def tell (self , x : Tuple [numbers .Integral , Any ], y : Any ) -> None :
257
+ def tell (self , x : tuple [numbers .Integral , Any ], y : Any ) -> None :
267
258
index , x = x
268
259
self ._ask_cache .pop (index , None )
269
260
self ._loss .pop (index , None )
270
261
self ._pending_loss .pop (index , None )
271
262
self .learners [index ].tell (x , y )
272
263
273
- def tell_pending (self , x : Tuple [numbers .Integral , Any ]) -> None :
264
+ def tell_pending (self , x : tuple [numbers .Integral , Any ]) -> None :
274
265
index , x = x
275
266
self ._ask_cache .pop (index , None )
276
267
self ._loss .pop (index , None )
277
268
self .learners [index ].tell_pending (x )
278
269
279
- def _losses (self , real : bool = True ) -> List [float ]:
270
+ def _losses (self , real : bool = True ) -> list [float ]:
280
271
losses = []
281
272
loss_dict = self ._loss if real else self ._pending_loss
282
273
@@ -294,8 +285,8 @@ def loss(self, real: bool = True) -> float:
294
285
295
286
def plot (
296
287
self ,
297
- cdims : Optional [ CDIMS_TYPE ] = None ,
298
- plotter : Optional [ Callable [[BaseLearner ], Any ]] = None ,
288
+ cdims : CDIMS_TYPE | None = None ,
289
+ plotter : Callable [[BaseLearner ], Any ] | None = None ,
299
290
dynamic : bool = True ,
300
291
):
301
292
"""Returns a DynamicMap with sliders.
@@ -380,9 +371,9 @@ def from_product(
380
371
cls ,
381
372
f ,
382
373
learner_type : BaseLearner ,
383
- learner_kwargs : Dict [str , Any ],
384
- combos : Dict [str , Sequence [Any ]],
385
- ) -> " BalancingLearner" :
374
+ learner_kwargs : dict [str , Any ],
375
+ combos : dict [str , Sequence [Any ]],
376
+ ) -> BalancingLearner :
386
377
"""Create a `BalancingLearner` with learners of all combinations of
387
378
named variables’ values. The `cdims` will be set correctly, so calling
388
379
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -431,7 +422,7 @@ def from_product(
431
422
432
423
def save (
433
424
self ,
434
- fname : Union [ Callable [[BaseLearner ], str ], Sequence [str ] ],
425
+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
435
426
compress : bool = True ,
436
427
) -> None :
437
428
"""Save the data of the child learners into pickle files
@@ -473,7 +464,7 @@ def save(
473
464
474
465
def load (
475
466
self ,
476
- fname : Union [ Callable [[BaseLearner ], str ], Sequence [str ] ],
467
+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
477
468
compress : bool = True ,
478
469
) -> None :
479
470
"""Load the data of the child learners from pickle files
@@ -499,20 +490,20 @@ def load(
499
490
for l in self .learners :
500
491
l .load (fname (l ), compress = compress )
501
492
502
- def _get_data (self ) -> List [Any ]:
493
+ def _get_data (self ) -> list [Any ]:
503
494
return [l ._get_data () for l in self .learners ]
504
495
505
- def _set_data (self , data : List [Any ]):
496
+ def _set_data (self , data : list [Any ]):
506
497
for l , _data in zip (self .learners , data ):
507
498
l ._set_data (_data )
508
499
509
- def __getstate__ (self ) -> Tuple [ List [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]:
500
+ def __getstate__ (self ) -> tuple [ list [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]:
510
501
return (
511
502
self .learners ,
512
503
self ._cdims_default ,
513
504
self .strategy ,
514
505
)
515
506
516
- def __setstate__ (self , state : Tuple [ List [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]):
507
+ def __setstate__ (self , state : tuple [ list [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]):
517
508
learners , cdims , strategy = state
518
509
self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments