1
1
import math
2
+ import sys
2
3
from collections import defaultdict
3
4
from copy import deepcopy
4
5
from math import hypot
23
24
from adaptive .notebook_integration import ensure_holoviews
24
25
25
26
number = Union [int , float , np .int_ , np .float_ ]
26
-
27
27
Point = Tuple [int , number ]
28
28
Points = List [Point ]
29
- Value = Union [number , Sequence [number ], np .ndarray ]
30
29
31
- __all__ = ["AverageLearner1D" ]
30
+ __all__ : List [ str ] = ["AverageLearner1D" ]
32
31
33
32
34
33
class AverageLearner1D (Learner1D ):
@@ -76,7 +75,7 @@ class AverageLearner1D(Learner1D):
76
75
77
76
def __init__ (
78
77
self ,
79
- function : Callable [[Tuple [int , number ]], Value ],
78
+ function : Callable [[Tuple [int , number ]], number ],
80
79
bounds : Tuple [number , number ],
81
80
loss_per_interval : Optional [
82
81
Callable [[Sequence [number ], Sequence [number ]], float ]
@@ -85,7 +84,7 @@ def __init__(
85
84
alpha : float = 0.005 ,
86
85
neighbor_sampling : float = 0.3 ,
87
86
min_samples : int = 50 ,
88
- max_samples : int = np . inf ,
87
+ max_samples : int = sys . maxsize ,
89
88
min_error : float = 0 ,
90
89
):
91
90
if not (0 < delta <= 1 ):
@@ -201,16 +200,13 @@ def tell_pending(self, seed_x: Point) -> None:
201
200
self ._update_neighbors (x , self .neighbors_combined )
202
201
self ._update_losses (x , real = False )
203
202
204
- def tell (self , seed_x : Point , y : Value ) -> None :
203
+ def tell (self , seed_x : Point , y : number ) -> None :
205
204
seed , x = seed_x
206
205
if y is None :
207
206
raise TypeError (
208
207
"Y-value may not be None, use learner.tell_pending(x)"
209
208
"to indicate that this value is currently being calculated"
210
209
)
211
- # either it is a float/int, if not, try casting to a np.array
212
- if not isinstance (y , (float , int )):
213
- y = np .asarray (y , dtype = float )
214
210
215
211
if x not in self .data :
216
212
self ._update_data (x , y , "new" )
@@ -257,15 +253,17 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
257
253
norm = min (d_left , d_right )
258
254
self .rescaled_error [x ] = self .error [x ] / norm
259
255
260
- def _update_data (self , x : number , y : Value , point_type : str ) -> None :
256
+ def _update_data (self , x : number , y : number , point_type : str ) -> None :
261
257
if point_type == "new" :
262
258
self .data [x ] = y
263
259
elif point_type == "resampled" :
264
260
n = len (self ._data_samples [x ])
265
261
new_average = self .data [x ] * n / (n + 1 ) + y / (n + 1 )
266
262
self .data [x ] = new_average
267
263
268
- def _update_data_structures (self , seed_x : Point , y : Value , point_type : str ) -> None :
264
+ def _update_data_structures (
265
+ self , seed_x : Point , y : number , point_type : str
266
+ ) -> None :
269
267
seed , x = seed_x
270
268
if point_type == "new" :
271
269
self ._data_samples [x ] = {seed : y }
@@ -370,12 +368,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
370
368
if (b is not None ) and right_loss_is_unknown :
371
369
self .losses_combined [x , b ] = float ("inf" )
372
370
373
- def _calc_error_in_mean (self , ys : Sequence [Value ], y_avg : Value , n : int ) -> float :
371
+ def _calc_error_in_mean (self , ys : Sequence [number ], y_avg : number , n : int ) -> float :
374
372
variance_in_mean = sum ((y - y_avg ) ** 2 for y in ys ) / (n - 1 )
375
373
t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
376
374
return t_student * (variance_in_mean / n ) ** 0.5
377
375
378
- def tell_many (self , xs : Points , ys : Sequence [Value ]) -> None :
376
+ def tell_many (self , xs : Points , ys : Sequence [number ]) -> None :
379
377
# Check that all x are within the bounds
380
378
# TODO: remove this requirement, all other learners add the data
381
379
# but ignore it going forward.
@@ -386,7 +384,7 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
386
384
)
387
385
388
386
# Create a mapping of points to a list of samples
389
- mapping : DefaultDict [number , DefaultDict [int , Value ]] = defaultdict (
387
+ mapping : DefaultDict [number , DefaultDict [int , number ]] = defaultdict (
390
388
lambda : defaultdict (dict )
391
389
)
392
390
for (seed , x ), y in zip (xs , ys ):
@@ -402,14 +400,14 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
402
400
# simultaneously, before we move on to a new x
403
401
self .tell_many_at_point (x , seed_y_mapping )
404
402
405
- def tell_many_at_point (self , x : float , seed_y_mapping : Dict [int , Value ]) -> None :
403
+ def tell_many_at_point (self , x : number , seed_y_mapping : Dict [int , number ]) -> None :
406
404
"""Tell the learner about many samples at a certain location x.
407
405
408
406
Parameters
409
407
----------
410
408
x : float
411
409
Value from the function domain.
412
- seed_y_mapping : Dict[int, Value ]
410
+ seed_y_mapping : Dict[int, number ]
413
411
Dictionary of ``seed`` -> ``y`` at ``x``.
414
412
"""
415
413
# Check x is within the bounds
@@ -458,10 +456,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
458
456
self ._update_interpolated_loss_in_interval (* interval )
459
457
self ._oldscale = deepcopy (self ._scale )
460
458
461
- def _get_data (self ) -> SortedDict [number , Value ]:
459
+ def _get_data (self ) -> SortedDict [number , number ]:
462
460
return self ._data_samples
463
461
464
- def _set_data (self , data : SortedDict [number , Value ]) -> None :
462
+ def _set_data (self , data : SortedDict [number , number ]) -> None :
465
463
if data :
466
464
for x , samples in data .items ():
467
465
self .tell_many_at_point (x , samples )
0 commit comments