3
3
from collections import defaultdict
4
4
from copy import deepcopy
5
5
from math import hypot
6
- from typing import (
7
- Callable ,
8
- DefaultDict ,
9
- Dict ,
10
- List ,
11
- Optional ,
12
- Sequence ,
13
- Set ,
14
- Tuple ,
15
- Union ,
16
- )
6
+ from typing import Callable , DefaultDict , Dict , List , Optional , Sequence , Set , Tuple
17
7
18
8
import numpy as np
19
9
import scipy .stats
22
12
23
13
from adaptive .learner .learner1D import Learner1D , _get_intervals
24
14
from adaptive .notebook_integration import ensure_holoviews
15
+ from adaptive .types import Real
25
16
26
- number = Union [int , float , np .int_ , np .float_ ]
27
- Point = Tuple [int , number ]
17
+ Point = Tuple [int , Real ]
28
18
Points = List [Point ]
29
19
30
20
__all__ : List [str ] = ["AverageLearner1D" ]
@@ -45,7 +35,7 @@ class AverageLearner1D(Learner1D):
45
35
If not provided, then a default is used, which uses the scaled distance
46
36
in the x-y plane as the loss. See the notes for more details
47
37
of `adaptive.Learner1D` for more details.
48
- delta : float
38
+ delta : float, optional, default 0.2
49
39
This parameter controls the resampling condition. A point is resampled
50
40
if its uncertainty is larger than delta times the smallest neighboring
51
41
interval.
@@ -75,10 +65,10 @@ class AverageLearner1D(Learner1D):
75
65
76
66
def __init__ (
77
67
self ,
78
- function : Callable [[Tuple [int , number ]], number ],
79
- bounds : Tuple [number , number ],
68
+ function : Callable [[Tuple [int , Real ]], Real ],
69
+ bounds : Tuple [Real , Real ],
80
70
loss_per_interval : Optional [
81
- Callable [[Sequence [number ], Sequence [number ]], float ]
71
+ Callable [[Sequence [Real ], Sequence [Real ]], float ]
82
72
] = None ,
83
73
delta : float = 0.2 ,
84
74
alpha : float = 0.005 ,
@@ -115,15 +105,15 @@ def __init__(
115
105
self ._number_samples = SortedDict ()
116
106
# This set contains the points x that have less than min_samples
117
107
# samples or less than a (neighbor_sampling*100)% of their neighbors
118
- self ._undersampled_points : Set [number ] = set ()
108
+ self ._undersampled_points : Set [Real ] = set ()
119
109
# Contains the error in the estimate of the
120
110
# mean at each point x in the form {x0: error(x0), ...}
121
- self .error : ItemSortedDict [number , float ] = decreasing_dict ()
111
+ self .error : ItemSortedDict [Real , float ] = decreasing_dict ()
122
112
# Distance between two neighboring points in the
123
113
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
124
- self ._distances : ItemSortedDict [number , float ] = decreasing_dict ()
114
+ self ._distances : ItemSortedDict [Real , float ] = decreasing_dict ()
125
115
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
126
- self .rescaled_error : ItemSortedDict [number , float ] = decreasing_dict ()
116
+ self .rescaled_error : ItemSortedDict [Real , float ] = decreasing_dict ()
127
117
128
118
@property
129
119
def nsamples (self ) -> int :
@@ -165,7 +155,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
165
155
166
156
return points , loss_improvements
167
157
168
- def _ask_for_more_samples (self , x : number , n : int ) -> Tuple [Points , List [float ]]:
158
+ def _ask_for_more_samples (self , x : Real , n : int ) -> Tuple [Points , List [float ]]:
169
159
"""When asking for n points, the learner returns n times an existing point
170
160
to be resampled, since in general n << min_samples and this point will
171
161
need to be resampled many more times"""
@@ -200,7 +190,7 @@ def tell_pending(self, seed_x: Point) -> None:
200
190
self ._update_neighbors (x , self .neighbors_combined )
201
191
self ._update_losses (x , real = False )
202
192
203
- def tell (self , seed_x : Point , y : number ) -> None :
193
+ def tell (self , seed_x : Point , y : Real ) -> None :
204
194
seed , x = seed_x
205
195
if y is None :
206
196
raise TypeError (
@@ -216,7 +206,7 @@ def tell(self, seed_x: Point, y: number) -> None:
216
206
self ._update_data_structures (seed_x , y , "resampled" )
217
207
self .pending_points .discard (seed_x )
218
208
219
- def _update_rescaled_error_in_mean (self , x : number , point_type : str ) -> None :
209
+ def _update_rescaled_error_in_mean (self , x : Real , point_type : str ) -> None :
220
210
"""Updates ``self.rescaled_error``.
221
211
222
212
Parameters
@@ -253,17 +243,15 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
253
243
norm = min (d_left , d_right )
254
244
self .rescaled_error [x ] = self .error [x ] / norm
255
245
256
- def _update_data (self , x : number , y : number , point_type : str ) -> None :
246
+ def _update_data (self , x : Real , y : Real , point_type : str ) -> None :
257
247
if point_type == "new" :
258
248
self .data [x ] = y
259
249
elif point_type == "resampled" :
260
250
n = len (self ._data_samples [x ])
261
251
new_average = self .data [x ] * n / (n + 1 ) + y / (n + 1 )
262
252
self .data [x ] = new_average
263
253
264
- def _update_data_structures (
265
- self , seed_x : Point , y : number , point_type : str
266
- ) -> None :
254
+ def _update_data_structures (self , seed_x : Point , y : Real , point_type : str ) -> None :
267
255
seed , x = seed_x
268
256
if point_type == "new" :
269
257
self ._data_samples [x ] = {seed : y }
@@ -331,15 +319,15 @@ def _update_data_structures(
331
319
self ._update_interpolated_loss_in_interval (* interval )
332
320
self ._oldscale = deepcopy (self ._scale )
333
321
334
- def _update_distances (self , x : number ) -> None :
322
+ def _update_distances (self , x : Real ) -> None :
335
323
x_left , x_right = self .neighbors [x ]
336
324
y = self .data [x ]
337
325
if x_left is not None :
338
326
self ._distances [x_left ] = hypot ((x - x_left ), (y - self .data [x_left ]))
339
327
if x_right is not None :
340
328
self ._distances [x ] = hypot ((x_right - x ), (self .data [x_right ] - y ))
341
329
342
- def _update_losses_resampling (self , x : number , real = True ) -> None :
330
+ def _update_losses_resampling (self , x : Real , real = True ) -> None :
343
331
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
344
332
# (x_left, x_right) are the "real" neighbors of 'x'.
345
333
x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -368,12 +356,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
368
356
if (b is not None ) and right_loss_is_unknown :
369
357
self .losses_combined [x , b ] = float ("inf" )
370
358
371
- def _calc_error_in_mean (self , ys : Sequence [number ], y_avg : number , n : int ) -> float :
359
+ def _calc_error_in_mean (self , ys : Sequence [Real ], y_avg : Real , n : int ) -> float :
372
360
variance_in_mean = sum ((y - y_avg ) ** 2 for y in ys ) / (n - 1 )
373
361
t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
374
362
return t_student * (variance_in_mean / n ) ** 0.5
375
363
376
- def tell_many (self , xs : Points , ys : Sequence [number ]) -> None :
364
+ def tell_many (self , xs : Points , ys : Sequence [Real ]) -> None :
377
365
# Check that all x are within the bounds
378
366
# TODO: remove this requirement, all other learners add the data
379
367
# but ignore it going forward.
@@ -384,7 +372,7 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
384
372
)
385
373
386
374
# Create a mapping of points to a list of samples
387
- mapping : DefaultDict [number , DefaultDict [int , number ]] = defaultdict (
375
+ mapping : DefaultDict [Real , DefaultDict [int , Real ]] = defaultdict (
388
376
lambda : defaultdict (dict )
389
377
)
390
378
for (seed , x ), y in zip (xs , ys ):
@@ -400,14 +388,14 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
400
388
# simultaneously, before we move on to a new x
401
389
self .tell_many_at_point (x , seed_y_mapping )
402
390
403
- def tell_many_at_point (self , x : number , seed_y_mapping : Dict [int , number ]) -> None :
391
+ def tell_many_at_point (self , x : Real , seed_y_mapping : Dict [int , Real ]) -> None :
404
392
"""Tell the learner about many samples at a certain location x.
405
393
406
394
Parameters
407
395
----------
408
396
x : float
409
397
Value from the function domain.
410
- seed_y_mapping : Dict[int, number ]
398
+ seed_y_mapping : Dict[int, Real ]
411
399
Dictionary of ``seed`` -> ``y`` at ``x``.
412
400
"""
413
401
# Check x is within the bounds
@@ -456,10 +444,10 @@ def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> No
456
444
self ._update_interpolated_loss_in_interval (* interval )
457
445
self ._oldscale = deepcopy (self ._scale )
458
446
459
- def _get_data (self ) -> SortedDict [number , number ]:
447
+ def _get_data (self ) -> SortedDict [Real , Real ]:
460
448
return self ._data_samples
461
449
462
- def _set_data (self , data : SortedDict [number , number ]) -> None :
450
+ def _set_data (self , data : SortedDict [Real , Real ]) -> None :
463
451
if data :
464
452
for x , samples in data .items ():
465
453
self .tell_many_at_point (x , samples )
0 commit comments