Skip to content

Commit 4a363a2

Browse files
committed
change function signature to: f(seed_x: Tuple[int, float]) for AverageLearner1D
1 parent 16c028b commit 4a363a2

File tree

3 files changed: 75 additions & 56 deletions

adaptive/learner/average_learner1D.py

Lines changed: 64 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
from collections import defaultdict
22
from copy import deepcopy
33
from math import hypot
4+
from numbers import Number
5+
from typing import Dict, List, Sequence, Tuple, Union
46

57
import numpy as np
68
import scipy.stats
@@ -10,6 +12,10 @@
1012
from adaptive.learner.learner1D import Learner1D, _get_intervals
1113
from adaptive.notebook_integration import ensure_holoviews
1214

15+
Point = Tuple[int, Number]
16+
Points = List[Point]
17+
Value = Union[Number, Sequence[Number]]
18+
1319

1420
class AverageLearner1D(Learner1D):
1521
"""Learns and predicts a noisy function 'f:ℝ → ℝ^N'.
@@ -77,7 +83,7 @@ def __init__(
7783
self.neighbor_sampling = neighbor_sampling
7884

7985
# Contains all samples f(x) for each
80-
# point x in the form {x0:[f_0(x0), f_1(x0), ...], ...}
86+
# point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
8187
self._data_samples = SortedDict()
8288
# Contains the number of samples taken
8389
# at each point x in the form {x0: n0, x1: n1, ...}
@@ -95,17 +101,17 @@ def __init__(
95101
self.rescaled_error = decreasing_dict()
96102

97103
@property
98-
def nsamples(self):
104+
def nsamples(self) -> int:
99105
"""Returns the total number of samples"""
100106
return sum(self._number_samples.values())
101107

102108
@property
103-
def min_samples_per_point(self):
109+
def min_samples_per_point(self) -> int:
104110
if not self._number_samples:
105111
return 0
106112
return min(self._number_samples.values())
107113

108-
def ask(self, n, tell_pending=True):
114+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
109115
"""Return 'n' points that are expected to maximally reduce the loss."""
110116
# If some point is undersampled, resample it
111117
if len(self._undersampled_points):
@@ -133,32 +139,34 @@ def ask(self, n, tell_pending=True):
133139

134140
return points, loss_improvements
135141

136-
def _ask_for_more_samples(self, x, n):
142+
def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]:
137143
"""When asking for n points, the learner returns n times an existing point
138144
to be resampled, since in general n << min_samples and this point will
139145
need to be resampled many more times"""
140-
points = [x] * n
146+
n_existing = self._number_samples.get(x, 0)
147+
points = [(seed + n_existing, x) for seed in range(n)]
148+
141149
loss_improvements = [0] * n # We set the loss_improvements of resamples to 0
142150
return points, loss_improvements
143151

144-
def _ask_for_new_point(self, n):
152+
def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
145153
"""When asking for n new points, the learner returns n times a single
146154
new point, since in general n << min_samples and this point will need
147155
to be resampled many more times"""
148156
points, loss_improvements = self._ask_points_without_adding(1)
149-
points = points * n
157+
points = [(seed, x) for seed, x in zip(range(n), n * points)]
150158
loss_improvements = loss_improvements + [0] * (n - 1)
151159
return points, loss_improvements
152160

153-
def tell_pending(self, x):
154-
if x in self.data:
155-
self.pending_points.add(x)
156-
else:
157-
self.pending_points.add(x)
161+
def tell_pending(self, seed_x: Point) -> None:
162+
_, x = seed_x
163+
self.pending_points.add(seed_x)
164+
if x not in self.data:
158165
self._update_neighbors(x, self.neighbors_combined)
159166
self._update_losses(x, real=False)
160167

161-
def tell(self, x, y):
168+
def tell(self, seed_x: Point, y: Value) -> None:
169+
seed, x = seed_x
162170
if y is None:
163171
raise TypeError(
164172
"Y-value may not be None, use learner.tell_pending(x)"
@@ -170,13 +178,13 @@ def tell(self, x, y):
170178

171179
if x not in self.data:
172180
self._update_data(x, y, "new")
173-
self._update_data_structures(x, y, "new")
174-
else:
181+
self._update_data_structures(seed_x, y, "new")
182+
elif seed not in self._data_samples[x]: # check if the seed is new
175183
self._update_data(x, y, "resampled")
176-
self._update_data_structures(x, y, "resampled")
177-
self.pending_points.discard(x)
184+
self._update_data_structures(seed_x, y, "resampled")
185+
self.pending_points.discard(seed_x)
178186

179-
def _update_rescaled_error_in_mean(self, x, point_type: str) -> None:
187+
def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
180188
"""Updates ``self.rescaled_error``.
181189
182190
Parameters
@@ -213,17 +221,18 @@ def _update_rescaled_error_in_mean(self, x, point_type: str) -> None:
213221
norm = min(d_left, d_right)
214222
self.rescaled_error[x] = self.error[x] / norm
215223

216-
def _update_data(self, x, y, point_type: str):
224+
def _update_data(self, x: Number, y: Value, point_type: str) -> None:
217225
if point_type == "new":
218226
self.data[x] = y
219227
elif point_type == "resampled":
220228
n = len(self._data_samples[x])
221229
new_average = self.data[x] * n / (n + 1) + y / (n + 1)
222230
self.data[x] = new_average
223231

224-
def _update_data_structures(self, x, y, point_type: str):
232+
def _update_data_structures(self, seed_x: Point, y: Value, point_type: str) -> None:
233+
seed, x = seed_x
225234
if point_type == "new":
226-
self._data_samples[x] = [y]
235+
self._data_samples[x] = {seed: y}
227236

228237
if not self.bounds[0] <= x <= self.bounds[1]:
229238
return
@@ -247,7 +256,7 @@ def _update_data_structures(self, x, y, point_type: str):
247256
self._update_rescaled_error_in_mean(x, "new")
248257

249258
elif point_type == "resampled":
250-
self._data_samples[x].append(y)
259+
self._data_samples[x][seed] = y
251260
ns = self._number_samples
252261
ns[x] += 1
253262
n = ns[x]
@@ -268,7 +277,7 @@ def _update_data_structures(self, x, y, point_type: str):
268277
# the std of the mean multiplied by a t-Student factor to ensure that
269278
# the mean value lies within the correct interval of confidence
270279
y_avg = self.data[x]
271-
ys = self._data_samples[x]
280+
ys = self._data_samples[x].values()
272281
self.error[x] = self._calc_error_in_mean(ys, y_avg, n)
273282
self._update_distances(x)
274283
self._update_rescaled_error_in_mean(x, "resampled")
@@ -288,15 +297,15 @@ def _update_data_structures(self, x, y, point_type: str):
288297
self._update_interpolated_loss_in_interval(*interval)
289298
self._oldscale = deepcopy(self._scale)
290299

291-
def _update_distances(self, x):
300+
def _update_distances(self, x: Number) -> None:
292301
x_left, x_right = self.neighbors[x]
293302
y = self.data[x]
294303
if x_left is not None:
295304
self._distances[x_left] = hypot((x - x_left), (y - self.data[x_left]))
296305
if x_right is not None:
297306
self._distances[x] = hypot((x_right - x), (self.data[x_right] - y))
298307

299-
def _update_losses_resampling(self, x, real=True):
308+
def _update_losses_resampling(self, x: Number, real=True) -> None:
300309
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
301310
# (x_left, x_right) are the "real" neighbors of 'x'.
302311
x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -325,42 +334,43 @@ def _update_losses_resampling(self, x, real=True):
325334
if (b is not None) and right_loss_is_unknown:
326335
self.losses_combined[x, b] = float("inf")
327336

328-
def _calc_error_in_mean(self, ys, y_avg, n):
337+
def _calc_error_in_mean(self, ys: Sequence[Value], y_avg: Value, n: int) -> float:
329338
variance_in_mean = sum((y - y_avg) ** 2 for y in ys) / (n - 1)
330339
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
331340
return t_student * (variance_in_mean / n) ** 0.5
332341

333-
def tell_many(self, xs, ys):
342+
def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
334343
# Check that all x are within the bounds
335-
if not np.prod([x >= self.bounds[0] and x <= self.bounds[1] for x in xs]):
344+
if not np.prod([x >= self.bounds[0] and x <= self.bounds[1] for _, x in xs]):
336345
raise ValueError(
337346
"x value out of bounds, "
338347
"remove x or enlarge the bounds of the learner"
339348
)
340349

341350
# Create a mapping of points to a list of samples
342-
mapping = defaultdict(list)
343-
for x, y in zip(xs, ys):
344-
mapping[x].append(y)
345-
346-
for x, ys in mapping.items():
347-
if len(ys) == 1:
348-
self.tell(x, ys[0])
349-
elif len(ys) > 1:
351+
mapping = defaultdict(lambda: defaultdict(dict))
352+
for (seed, x), y in zip(xs, ys):
353+
mapping[x][seed] = y
354+
355+
for x, seed_y_mapping in mapping.items():
356+
if len(seed_y_mapping) == 1:
357+
seed, y = list(seed_y_mapping.items())[0]
358+
self.tell((seed, x), y)
359+
elif len(seed_y_mapping) > 1:
350360
# If we stored more than 1 y-value for the previous x,
351361
# use a more efficient routine to tell many samples
352362
# simultaneously, before we move on to a new x
353-
self.tell_many_at_point(x, ys)
363+
self.tell_many_at_point(x, seed_y_mapping)
354364

355-
def tell_many_at_point(self, x, ys):
365+
def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None:
356366
"""Tell the learner about many samples at a certain location x.
357367
358368
Parameters
359369
----------
360370
x : float
361371
Value from the function domain.
362-
ys : List[float]
363-
List of data samples at ``x``.
372+
seed_y_mapping : Dict[int, Value]
373+
Dictionary of ``seed`` -> ``y`` at ``x``.
364374
"""
365375
# Check x is within the bounds
366376
if not np.prod(x >= self.bounds[0] and x <= self.bounds[1]):
@@ -369,16 +379,20 @@ def tell_many_at_point(self, x, ys):
369379
"remove x or enlarge the bounds of the learner"
370380
)
371381

372-
ys = list(ys) # cast to list *and* make a copy
373382
# If x is a new point:
374383
if x not in self.data:
375-
y = ys.pop(0)
384+
# we make a copy because we don't want to modify the original dict
385+
seed_y_mapping = seed_y_mapping.copy()
386+
seed = next(iter(seed_y_mapping))
387+
y = seed_y_mapping.pop(seed)
376388
self._update_data(x, y, "new")
377-
self._update_data_structures(x, y, "new")
389+
self._update_data_structures((seed, x), y, "new")
390+
391+
ys = list(seed_y_mapping.values()) # cast to list *and* make a copy
378392

379393
# If x is not a new point or if there were more than 1 sample in ys:
380394
if len(ys) > 0:
381-
self._data_samples[x].extend(ys)
395+
self._data_samples[x].update(seed_y_mapping)
382396
n = len(ys) + self._number_samples[x]
383397
self.data[x] = (
384398
np.mean(ys) * len(ys) + self.data[x] * self._number_samples[x]
@@ -390,24 +404,24 @@ def tell_many_at_point(self, x, ys):
390404
if n > self.min_samples:
391405
self._undersampled_points.discard(x)
392406
self.error[x] = self._calc_error_in_mean(
393-
self._data_samples[x], self.data[x], n
407+
self._data_samples[x].values(), self.data[x], n
394408
)
395409
self._update_distances(x)
396410
self._update_rescaled_error_in_mean(x, "resampled")
397411
if self.error[x] <= self.min_error or n >= self.max_samples:
398412
self.rescaled_error.pop(x, None)
399-
self._update_scale(x, min(self._data_samples[x]))
400-
self._update_scale(x, max(self._data_samples[x]))
413+
self._update_scale(x, min(self._data_samples[x].values()))
414+
self._update_scale(x, max(self._data_samples[x].values()))
401415
self._update_losses_resampling(x, real=True)
402416
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
403417
for interval in reversed(self.losses):
404418
self._update_interpolated_loss_in_interval(*interval)
405419
self._oldscale = deepcopy(self._scale)
406420

407-
def _get_data(self):
421+
def _get_data(self) -> SortedDict:
408422
return self._data_samples
409423

410-
def _set_data(self, data):
424+
def _set_data(self, data: SortedDict) -> None:
411425
if data:
412426
for x, samples in data.items():
413427
self.tell_many_at_point(x, samples)

adaptive/tests/test_average_learner1d.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def test_tell_many_at_point():
3737
for k, v1 in learner._data_samples.items():
3838
v2 = control._data_samples[k]
3939
assert len(v1) == len(v2)
40-
np.testing.assert_almost_equal(np.sort(v1), np.sort(v2))
40+
np.testing.assert_almost_equal(sorted(v1.values()), sorted(v2.values()))
4141

4242
assert learner._bbox[0] == control._bbox[0]
4343
assert learner._bbox[1] == control._bbox[1]

adaptive/tests/test_learners.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -166,11 +166,12 @@ def gaussian(n):
166166

167167
@learn_with(AverageLearner1D, bounds=[-2, 2])
168168
def noisy_peak(
169-
x,
169+
seed_x,
170170
sigma: uniform(1.5, 2.5),
171171
peak_width: uniform(0.04, 0.06),
172172
offset: uniform(-0.6, -0.3),
173173
):
174+
seed, x = seed_x
174175
y = x ** 3 - x + 3 * peak_width ** 2 / (peak_width ** 2 + (x - offset) ** 2)
175176
noise = np.random.normal(0, sigma)
176177
return y + noise
@@ -411,13 +412,17 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
411412
learner.tell(*p)
412413

413414
M = random.randint(10, 30)
414-
pls = zip(*learner.ask(M))
415-
cpls = zip(*control.ask(M))
415+
pls = sorted(zip(*learner.ask(M)))
416+
cpls = sorted(zip(*control.ask(M)))
416417
# Point ordering within a single call to 'ask'
417418
# is not guaranteed to be the same by the API.
418419
# We compare the sorted points instead of set, because the points
419420
# should only be identical up to machine precision.
420-
np.testing.assert_almost_equal(sorted(pls), sorted(cpls))
421+
if isinstance(pls[0][0], tuple):
422+
# This is the case for AverageLearner1D
423+
pls = [(*x, y) for x, y in pls]
424+
cpls = [(*x, y) for x, y in cpls]
425+
np.testing.assert_almost_equal(pls, cpls)
421426

422427

423428
# XXX: the Learner2D fails with ~50% chance
@@ -473,7 +478,7 @@ def test_learner_performance_is_invariant_under_scaling(
473478
l_kwargs["bounds"] = xscale * np.array(l_kwargs["bounds"])
474479
learner = learner_type(lambda x: yscale * f(np.array(x) / xscale), **l_kwargs)
475480

476-
if learner_type in [Learner1D, LearnerND]:
481+
if learner_type in [Learner1D, LearnerND, AverageLearner1D]:
477482
learner._recompute_losses_factor = 1
478483
control._recompute_losses_factor = 1
479484

0 commit comments

Comments
 (0)