Skip to content

Commit bb7960d

Browse files
committed
bugfix in tell_many_at_point()
The new value of the data at point adopted the value of the mean of the new data samples, instead of the mean over all samples (new and old). Fixed!
1 parent 616adc8 commit bb7960d

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@ def tell_many_at_point(self, x, ys):
370370
)
371371

372372
ys = list(ys) # cast to list *and* make a copy
373-
y_avg = np.mean(ys)
374373
# If x is a new point:
375374
if x not in self.data:
376375
y = ys.pop(0)
@@ -379,21 +378,23 @@ def tell_many_at_point(self, x, ys):
379378

380379
# If x is not a new point or if there were more than 1 sample in ys:
381380
if len(ys) > 0:
382-
self.data[x] = y_avg
383381
self._data_samples[x].extend(ys)
384-
n = len(self._data_samples[x])
382+
n = len(ys)+self._number_samples[x]
383+
# Same as n=len(self._data_samples[x]) but faster
384+
self.data[x] = (np.mean(ys)*len(ys) + self.data[x]*self._number_samples[x])/n
385+
# Same as self.data[x]=np.mean(self._data_samples[x]) but faster
385386
self._number_samples[x] = n
386387
# `self._update_data(x, y, "new")` included the point
387388
# in _undersampled_points. We remove it if there are
388389
# more than min_samples samples, disregarding neighbor_sampling.
389390
if n > self.min_samples:
390391
self._undersampled_points.discard(x)
391-
self.error[x] = self._calc_error_in_mean(self._data_samples[x], y_avg, n)
392+
self.error[x] = self._calc_error_in_mean(self._data_samples[x], self.data[x], n)
392393
self._update_distances(x)
393394
self._update_rescaled_error_in_mean(x, "resampled")
394395
if self.error[x] <= self.min_error or n >= self.max_samples:
395396
self.rescaled_error.pop(x, None)
396-
super()._update_scale(x, y_avg)
397+
super()._update_scale(x, self.data[x])
397398
self._update_losses_resampling(x, real=True)
398399
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
399400
for interval in reversed(self.losses):

0 commit comments

Comments
 (0)