2
2
from collections import defaultdict
3
3
from copy import deepcopy
4
4
from math import hypot
5
- from numbers import Number
6
- from typing import Dict , List , Sequence , Tuple , Union
5
+ from typing import (
6
+ Callable ,
7
+ DefaultDict ,
8
+ Dict ,
9
+ List ,
10
+ Optional ,
11
+ Sequence ,
12
+ Set ,
13
+ Tuple ,
14
+ Union ,
15
+ )
7
16
8
17
import numpy as np
9
18
import scipy .stats
13
22
from adaptive .learner .learner1D import Learner1D , _get_intervals
14
23
from adaptive .notebook_integration import ensure_holoviews
15
24
16
- Point = Tuple [int , Number ]
25
+ number = Union [int , float , np .int_ , np .float_ ]
26
+
27
+ Point = Tuple [int , number ]
17
28
Points = List [Point ]
18
- Value = Union [Number , Sequence [Number ]]
29
+ Value = Union [number , Sequence [number ], np .ndarray ]
30
+
31
+ __all__ = ["AverageLearner1D" ]
19
32
20
33
21
34
class AverageLearner1D (Learner1D ):
@@ -37,21 +50,21 @@ class AverageLearner1D(Learner1D):
37
50
This parameter controls the resampling condition. A point is resampled
38
51
if its uncertainty is larger than delta times the smallest neighboring
39
52
interval.
40
- We strongly recommend 0 < delta <= 1.
41
- alpha : float (0 < alpha < 1)
53
+ We strongly recommend `` 0 < delta <= 1`` .
54
+ alpha : float (0 < alpha < 1), default 0.005
42
55
The true value of the function at x is within the confidence interval
43
- [self.data[x] - self.error[x], self.data[x] +
44
- self.error[x]] with probability 1-2*alpha.
45
- We recommend to keep alpha=0.005.
46
- neighbor_sampling : float (0 < neighbor_sampling <= 1)
56
+ `` [self.data[x] - self.error[x], self.data[x] + self.error[x]]`` with
57
+ probability `` 1-2*alpha`` .
58
+ We recommend to keep `` alpha=0.005`` .
59
+ neighbor_sampling : float (0 < neighbor_sampling <= 1), default 0.3
47
60
Each new point is initially sampled at least a (neighbor_sampling*100)%
48
61
of the average number of samples of its neighbors.
49
- min_samples : int (min_samples > 0)
62
+ min_samples : int (min_samples > 0), default 50
50
63
Minimum number of samples at each point x. Each new point is initially
51
64
sampled at least min_samples times.
52
- max_samples : int (min_samples < max_samples)
65
+ max_samples : int (min_samples < max_samples), default np.inf
53
66
Maximum number of samples at each point x.
54
- min_error : float (min_error >= 0)
67
+ min_error : float (min_error >= 0), default 0
55
68
Minimum size of the confidence intervals. The true value of the
56
69
function at x is within the confidence interval [self.data[x] -
57
70
self.error[x], self.data[x] + self.error[x]] with
@@ -63,15 +76,17 @@ class AverageLearner1D(Learner1D):
63
76
64
77
def __init__ (
65
78
self ,
66
- function ,
67
- bounds ,
68
- loss_per_interval = None ,
69
- delta = 0.2 ,
70
- alpha = 0.005 ,
71
- neighbor_sampling = 0.3 ,
72
- min_samples = 50 ,
73
- max_samples = np .inf ,
74
- min_error = 0 ,
79
+ function : Callable [[Tuple [int , number ]], Value ],
80
+ bounds : Tuple [number , number ],
81
+ loss_per_interval : Optional [
82
+ Callable [[Sequence [number ], Sequence [number ]], float ]
83
+ ] = None ,
84
+ delta : float = 0.2 ,
85
+ alpha : float = 0.005 ,
86
+ neighbor_sampling : float = 0.3 ,
87
+ min_samples : int = 50 ,
88
+ max_samples : int = np .inf ,
89
+ min_error : float = 0 ,
75
90
):
76
91
if not (0 < delta <= 1 ):
77
92
raise ValueError ("Learner requires 0 < delta <= 1." )
@@ -101,15 +116,15 @@ def __init__(
101
116
self ._number_samples = SortedDict ()
102
117
# This set contains the points x that have less than min_samples
103
118
# samples or less than a (neighbor_sampling*100)% of their neighbors
104
- self ._undersampled_points = set ()
119
+ self ._undersampled_points : Set [ number ] = set ()
105
120
# Contains the error in the estimate of the
106
121
# mean at each point x in the form {x0: error(x0), ...}
107
- self .error = decreasing_dict ()
122
+ self .error : ItemSortedDict [ number , float ] = decreasing_dict ()
108
123
# Distance between two neighboring points in the
109
124
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
110
- self ._distances = decreasing_dict ()
125
+ self ._distances : ItemSortedDict [ number , float ] = decreasing_dict ()
111
126
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
112
- self .rescaled_error = decreasing_dict ()
127
+ self .rescaled_error : ItemSortedDict [ number , float ] = decreasing_dict ()
113
128
114
129
@property
115
130
def nsamples (self ) -> int :
@@ -151,7 +166,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
151
166
152
167
return points , loss_improvements
153
168
154
- def _ask_for_more_samples (self , x : Number , n : int ) -> Tuple [Points , List [float ]]:
169
+ def _ask_for_more_samples (self , x : number , n : int ) -> Tuple [Points , List [float ]]:
155
170
"""When asking for n points, the learner returns n times an existing point
156
171
to be resampled, since in general n << min_samples and this point will
157
172
need to be resampled many more times"""
@@ -205,7 +220,7 @@ def tell(self, seed_x: Point, y: Value) -> None:
205
220
self ._update_data_structures (seed_x , y , "resampled" )
206
221
self .pending_points .discard (seed_x )
207
222
208
- def _update_rescaled_error_in_mean (self , x : Number , point_type : str ) -> None :
223
+ def _update_rescaled_error_in_mean (self , x : number , point_type : str ) -> None :
209
224
"""Updates ``self.rescaled_error``.
210
225
211
226
Parameters
@@ -242,7 +257,7 @@ def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
242
257
norm = min (d_left , d_right )
243
258
self .rescaled_error [x ] = self .error [x ] / norm
244
259
245
- def _update_data (self , x : Number , y : Value , point_type : str ) -> None :
260
+ def _update_data (self , x : number , y : Value , point_type : str ) -> None :
246
261
if point_type == "new" :
247
262
self .data [x ] = y
248
263
elif point_type == "resampled" :
@@ -318,15 +333,15 @@ def _update_data_structures(self, seed_x: Point, y: Value, point_type: str) -> N
318
333
self ._update_interpolated_loss_in_interval (* interval )
319
334
self ._oldscale = deepcopy (self ._scale )
320
335
321
- def _update_distances (self , x : Number ) -> None :
336
+ def _update_distances (self , x : number ) -> None :
322
337
x_left , x_right = self .neighbors [x ]
323
338
y = self .data [x ]
324
339
if x_left is not None :
325
340
self ._distances [x_left ] = hypot ((x - x_left ), (y - self .data [x_left ]))
326
341
if x_right is not None :
327
342
self ._distances [x ] = hypot ((x_right - x ), (self .data [x_right ] - y ))
328
343
329
- def _update_losses_resampling (self , x : Number , real = True ) -> None :
344
+ def _update_losses_resampling (self , x : number , real = True ) -> None :
330
345
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
331
346
# (x_left, x_right) are the "real" neighbors of 'x'.
332
347
x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -371,7 +386,9 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
371
386
)
372
387
373
388
# Create a mapping of points to a list of samples
374
- mapping = defaultdict (lambda : defaultdict (dict ))
389
+ mapping : DefaultDict [number , DefaultDict [int , Value ]] = defaultdict (
390
+ lambda : defaultdict (dict )
391
+ )
375
392
for (seed , x ), y in zip (xs , ys ):
376
393
mapping [x ][seed ] = y
377
394
@@ -411,7 +428,7 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
411
428
self ._update_data (x , y , "new" )
412
429
self ._update_data_structures ((seed , x ), y , "new" )
413
430
414
- ys = list (seed_y_mapping .values ()) # cast to list *and* make a copy
431
+ ys = np . array ( list (seed_y_mapping .values ()))
415
432
416
433
# If x is not a new point or if there were more than 1 sample in ys:
417
434
if len (ys ) > 0 :
@@ -441,10 +458,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
441
458
self ._update_interpolated_loss_in_interval (* interval )
442
459
self ._oldscale = deepcopy (self ._scale )
443
460
444
- def _get_data (self ) -> SortedDict :
461
+ def _get_data (self ) -> SortedDict [ number , Value ] :
445
462
return self ._data_samples
446
463
447
- def _set_data (self , data : SortedDict ) -> None :
464
+ def _set_data (self , data : SortedDict [ number , Value ] ) -> None :
448
465
if data :
449
466
for x , samples in data .items ():
450
467
self .tell_many_at_point (x , samples )
@@ -478,7 +495,7 @@ def plot(self):
478
495
return p .redim (x = dict (range = plot_bounds ))
479
496
480
497
481
- def decreasing_dict ():
498
+ def decreasing_dict () -> ItemSortedDict :
482
499
"""This initialization orders the dictionary from large to small values"""
483
500
484
501
def sorting_rule (key , value ):
0 commit comments