@@ -370,7 +370,6 @@ def tell_many_at_point(self, x, ys):
370
370
)
371
371
372
372
ys = list (ys ) # cast to list *and* make a copy
373
- y_avg = np .mean (ys )
374
373
# If x is a new point:
375
374
if x not in self .data :
376
375
y = ys .pop (0 )
@@ -379,21 +378,23 @@ def tell_many_at_point(self, x, ys):
379
378
380
379
# If x is not a new point or if there were more than 1 sample in ys:
381
380
if len (ys ) > 0 :
382
- self .data [x ] = y_avg
383
381
self ._data_samples [x ].extend (ys )
384
- n = len (self ._data_samples [x ])
382
+ n = len (ys )+ self ._number_samples [x ]
383
+ # Same as n=len(self._data_samples[x]) but faster
384
+ self .data [x ] = (np .mean (ys )* len (ys ) + self .data [x ]* self ._number_samples [x ])/ n
385
+ # Same as self.data[x]=np.mean(self._data_samples[x]) but faster
385
386
self ._number_samples [x ] = n
386
387
# `self._update_data(x, y, "new")` included the point
387
388
# in _undersampled_points. We remove it if there are
388
389
# more than min_samples samples, disregarding neighbor_sampling.
389
390
if n > self .min_samples :
390
391
self ._undersampled_points .discard (x )
391
- self .error [x ] = self ._calc_error_in_mean (self ._data_samples [x ], y_avg , n )
392
+ self .error [x ] = self ._calc_error_in_mean (self ._data_samples [x ], self . data [ x ] , n )
392
393
self ._update_distances (x )
393
394
self ._update_rescaled_error_in_mean (x , "resampled" )
394
395
if self .error [x ] <= self .min_error or n >= self .max_samples :
395
396
self .rescaled_error .pop (x , None )
396
- super ()._update_scale (x , y_avg )
397
+ super ()._update_scale (x , self . data [ x ] )
397
398
self ._update_losses_resampling (x , real = True )
398
399
if self ._scale [1 ] > self ._recompute_losses_factor * self ._oldscale [1 ]:
399
400
for interval in reversed (self .losses ):
0 commit comments