1
1
from collections import defaultdict
2
2
from copy import deepcopy
3
3
from math import hypot
4
+ from numbers import Number
5
+ from typing import Dict , List , Sequence , Tuple , Union
4
6
5
7
import numpy as np
6
8
import scipy .stats
10
12
from adaptive .learner .learner1D import Learner1D , _get_intervals
11
13
from adaptive .notebook_integration import ensure_holoviews
12
14
15
+ Point = Tuple [int , Number ]
16
+ Points = List [Point ]
17
+ Value = Union [Number , Sequence [Number ]]
18
+
13
19
14
20
class AverageLearner1D (Learner1D ):
15
21
"""Learns and predicts a noisy function 'f:ℝ → ℝ^N'.
@@ -77,7 +83,7 @@ def __init__(
77
83
self .neighbor_sampling = neighbor_sampling
78
84
79
85
# Contains all samples f(x) for each
80
- # point x in the form {x0:[ f_0(x0), f_1(x0), ...] , ...}
86
+ # point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...} , ...}
81
87
self ._data_samples = SortedDict ()
82
88
# Contains the number of samples taken
83
89
# at each point x in the form {x0: n0, x1: n1, ...}
@@ -95,17 +101,17 @@ def __init__(
95
101
self .rescaled_error = decreasing_dict ()
96
102
97
103
@property
def nsamples(self) -> int:
    """Total number of samples taken, summed over all points."""
    total = 0
    for count in self._number_samples.values():
        total += count
    return total
101
107
102
108
@property
def min_samples_per_point(self) -> int:
    """Smallest number of samples taken at any single point (0 when no samples exist)."""
    return min(self._number_samples.values(), default=0)
107
113
108
- def ask (self , n , tell_pending = True ):
114
+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ Points , List [ float ]] :
109
115
"""Return 'n' points that are expected to maximally reduce the loss."""
110
116
# If some point is undersampled, resample it
111
117
if len (self ._undersampled_points ):
@@ -133,32 +139,34 @@ def ask(self, n, tell_pending=True):
133
139
134
140
return points , loss_improvements
135
141
136
def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]:
    """Return ``n`` (seed, x) points that all resample the existing point ``x``.

    Since in general n << min_samples, the same point is returned ``n`` times
    so it accumulates the many resamples it will need.  Seeds continue from
    the number of samples already taken at ``x``.
    """
    offset = self._number_samples.get(x, 0)
    points = []
    for i in range(n):
        points.append((offset + i, x))
    # Resamples are assigned a loss improvement of 0.
    loss_improvements = n * [0]
    return points, loss_improvements
143
151
144
def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
    """Return ``n`` (seed, x) points that all sample a single *new* x.

    Since in general n << min_samples, one new point is returned ``n`` times:
    it will need to be resampled many more times anyway.
    """
    raw_points, loss_improvements = self._ask_points_without_adding(1)
    (new_x,) = raw_points
    points = [(seed, new_x) for seed in range(n)]
    # Only the first sample carries a loss improvement; the resamples get 0.
    loss_improvements = loss_improvements + [0] * (n - 1)
    return points, loss_improvements
152
160
153
def tell_pending(self, seed_x: Point) -> None:
    """Register the (seed, x) pair as pending.

    For an x that has no data yet, also update the combined neighbor
    structure and the (non-real) losses.
    """
    _, x = seed_x
    self.pending_points.add(seed_x)
    if x in self.data:
        return
    self._update_neighbors(x, self.neighbors_combined)
    self._update_losses(x, real=False)
160
167
161
- def tell (self , x , y ):
168
+ def tell (self , seed_x : Point , y : Value ) -> None :
169
+ seed , x = seed_x
162
170
if y is None :
163
171
raise TypeError (
164
172
"Y-value may not be None, use learner.tell_pending(x)"
@@ -170,13 +178,13 @@ def tell(self, x, y):
170
178
171
179
if x not in self .data :
172
180
self ._update_data (x , y , "new" )
173
- self ._update_data_structures (x , y , "new" )
174
- else :
181
+ self ._update_data_structures (seed_x , y , "new" )
182
+ elif seed not in self . _data_samples [ x ]: # check if the seed is new
175
183
self ._update_data (x , y , "resampled" )
176
- self ._update_data_structures (x , y , "resampled" )
177
- self .pending_points .discard (x )
184
+ self ._update_data_structures (seed_x , y , "resampled" )
185
+ self .pending_points .discard (seed_x )
178
186
179
- def _update_rescaled_error_in_mean (self , x , point_type : str ) -> None :
187
+ def _update_rescaled_error_in_mean (self , x : Number , point_type : str ) -> None :
180
188
"""Updates ``self.rescaled_error``.
181
189
182
190
Parameters
@@ -213,17 +221,18 @@ def _update_rescaled_error_in_mean(self, x, point_type: str) -> None:
213
221
norm = min (d_left , d_right )
214
222
self .rescaled_error [x ] = self .error [x ] / norm
215
223
216
def _update_data(self, x: Number, y: Value, point_type: str) -> None:
    """Store ``y`` for a new ``x`` or fold it into the running mean.

    ``point_type`` is either ``"new"`` (first sample at x) or
    ``"resampled"`` (additional sample at an existing x).
    """
    if point_type == "new":
        self.data[x] = y
    elif point_type == "resampled":
        count = len(self._data_samples[x])
        # Incremental mean: reweight the old mean and add the new sample.
        self.data[x] = self.data[x] * count / (count + 1) + y / (count + 1)
223
231
224
- def _update_data_structures (self , x , y , point_type : str ):
232
+ def _update_data_structures (self , seed_x : Point , y : Value , point_type : str ) -> None :
233
+ seed , x = seed_x
225
234
if point_type == "new" :
226
- self ._data_samples [x ] = [ y ]
235
+ self ._data_samples [x ] = { seed : y }
227
236
228
237
if not self .bounds [0 ] <= x <= self .bounds [1 ]:
229
238
return
@@ -247,7 +256,7 @@ def _update_data_structures(self, x, y, point_type: str):
247
256
self ._update_rescaled_error_in_mean (x , "new" )
248
257
249
258
elif point_type == "resampled" :
250
- self ._data_samples [x ]. append ( y )
259
+ self ._data_samples [x ][ seed ] = y
251
260
ns = self ._number_samples
252
261
ns [x ] += 1
253
262
n = ns [x ]
@@ -268,7 +277,7 @@ def _update_data_structures(self, x, y, point_type: str):
268
277
# the std of the mean multiplied by a t-Student factor to ensure that
269
278
# the mean value lies within the correct interval of confidence
270
279
y_avg = self .data [x ]
271
- ys = self ._data_samples [x ]
280
+ ys = self ._data_samples [x ]. values ()
272
281
self .error [x ] = self ._calc_error_in_mean (ys , y_avg , n )
273
282
self ._update_distances (x )
274
283
self ._update_rescaled_error_in_mean (x , "resampled" )
@@ -288,15 +297,15 @@ def _update_data_structures(self, x, y, point_type: str):
288
297
self ._update_interpolated_loss_in_interval (* interval )
289
298
self ._oldscale = deepcopy (self ._scale )
290
299
291
def _update_distances(self, x: Number) -> None:
    """Refresh the Euclidean distances between ``x`` and its real neighbors."""
    left, right = self.neighbors[x]
    y = self.data[x]
    if left is not None:
        self._distances[left] = hypot(x - left, y - self.data[left])
    if right is not None:
        self._distances[x] = hypot(right - x, self.data[right] - y)
298
307
299
- def _update_losses_resampling (self , x , real = True ):
308
+ def _update_losses_resampling (self , x : Number , real = True ) -> None :
300
309
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
301
310
# (x_left, x_right) are the "real" neighbors of 'x'.
302
311
x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -325,42 +334,43 @@ def _update_losses_resampling(self, x, real=True):
325
334
if (b is not None ) and right_loss_is_unknown :
326
335
self .losses_combined [x , b ] = float ("inf" )
327
336
328
def _calc_error_in_mean(self, ys: Sequence[Value], y_avg: Value, n: int) -> float:
    """Return the error of the mean of ``ys`` (std of the mean scaled by a
    Student-t factor at confidence level ``1 - self.alpha``)."""
    total_sq_dev = 0
    for y in ys:
        total_sq_dev += (y - y_avg) ** 2
    variance_in_mean = total_sq_dev / (n - 1)
    # Student-t factor corrects the small-sample variance estimate.
    t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
    return t_student * (variance_in_mean / n) ** 0.5
332
341
333
def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
    """Tell the learner about many samples at once.

    Parameters
    ----------
    xs : list of (seed, x) tuples
        Sampled locations, each identified by a seed and an x-value.
    ys : sequence of values
        The corresponding samples, one per (seed, x) pair.

    Raises
    ------
    ValueError
        If any x lies outside ``self.bounds``.
    """
    # Check that all x are within the bounds.
    # (all() short-circuits; np.prod over a bool list built the whole list
    # and returned a float.)
    if not all(self.bounds[0] <= x <= self.bounds[1] for _, x in xs):
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    # Group the samples per point: {x: {seed: y, ...}, ...}.
    mapping: Dict[Number, Dict[int, Value]] = defaultdict(dict)
    for (seed, x), y in zip(xs, ys):
        mapping[x][seed] = y

    for x, seed_y_mapping in mapping.items():
        if len(seed_y_mapping) == 1:
            seed, y = next(iter(seed_y_mapping.items()))
            self.tell((seed, x), y)
        elif len(seed_y_mapping) > 1:
            # If we stored more than 1 y-value for this x, use a more
            # efficient routine to tell many samples simultaneously,
            # before we move on to a new x.
            self.tell_many_at_point(x, seed_y_mapping)
354
364
355
- def tell_many_at_point (self , x , ys ) :
365
+ def tell_many_at_point (self , x : float , seed_y_mapping : Dict [ int , Value ]) -> None :
356
366
"""Tell the learner about many samples at a certain location x.
357
367
358
368
Parameters
359
369
----------
360
370
x : float
361
371
Value from the function domain.
362
- ys : List[float ]
363
- List of data samples at ``x``.
372
+ seed_y_mapping : Dict[int, Value ]
373
+ Dictionary of ``seed`` -> ``y`` at ``x``.
364
374
"""
365
375
# Check x is within the bounds
366
376
if not np .prod (x >= self .bounds [0 ] and x <= self .bounds [1 ]):
@@ -369,16 +379,20 @@ def tell_many_at_point(self, x, ys):
369
379
"remove x or enlarge the bounds of the learner"
370
380
)
371
381
372
- ys = list (ys ) # cast to list *and* make a copy
373
382
# If x is a new point:
374
383
if x not in self .data :
375
- y = ys .pop (0 )
384
+ # we make a copy because we don't want to modify the original dict
385
+ seed_y_mapping = seed_y_mapping .copy ()
386
+ seed = next (iter (seed_y_mapping ))
387
+ y = seed_y_mapping .pop (seed )
376
388
self ._update_data (x , y , "new" )
377
- self ._update_data_structures (x , y , "new" )
389
+ self ._update_data_structures ((seed , x ), y , "new" )
390
+
391
+ ys = list (seed_y_mapping .values ()) # cast to list *and* make a copy
378
392
379
393
# If x is not a new point or if there were more than 1 sample in ys:
380
394
if len (ys ) > 0 :
381
- self ._data_samples [x ].extend ( ys )
395
+ self ._data_samples [x ].update ( seed_y_mapping )
382
396
n = len (ys ) + self ._number_samples [x ]
383
397
self .data [x ] = (
384
398
np .mean (ys ) * len (ys ) + self .data [x ] * self ._number_samples [x ]
@@ -390,24 +404,24 @@ def tell_many_at_point(self, x, ys):
390
404
if n > self .min_samples :
391
405
self ._undersampled_points .discard (x )
392
406
self .error [x ] = self ._calc_error_in_mean (
393
- self ._data_samples [x ], self .data [x ], n
407
+ self ._data_samples [x ]. values () , self .data [x ], n
394
408
)
395
409
self ._update_distances (x )
396
410
self ._update_rescaled_error_in_mean (x , "resampled" )
397
411
if self .error [x ] <= self .min_error or n >= self .max_samples :
398
412
self .rescaled_error .pop (x , None )
399
- self ._update_scale (x , min (self ._data_samples [x ]))
400
- self ._update_scale (x , max (self ._data_samples [x ]))
413
+ self ._update_scale (x , min (self ._data_samples [x ]. values () ))
414
+ self ._update_scale (x , max (self ._data_samples [x ]. values () ))
401
415
self ._update_losses_resampling (x , real = True )
402
416
if self ._scale [1 ] > self ._recompute_losses_factor * self ._oldscale [1 ]:
403
417
for interval in reversed (self .losses ):
404
418
self ._update_interpolated_loss_in_interval (* interval )
405
419
self ._oldscale = deepcopy (self ._scale )
406
420
407
def _get_data(self) -> SortedDict:
    """Return the raw samples, ``{x: {seed: y, ...}, ...}``."""
    return self._data_samples
409
423
410
def _set_data(self, data: SortedDict) -> None:
    """Load previously saved samples (``{x: {seed: y, ...}, ...}``) back in."""
    if not data:
        return
    for x, samples in data.items():
        self.tell_many_at_point(x, samples)
0 commit comments