Skip to content

Commit cf8af90

Browse files
committed
change the way the loss_improvement is returned
1 parent 3bfb22d commit cf8af90

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,19 @@ def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]
146146
need to be resampled many more times"""
147147
n_existing = self._number_samples.get(x, 0)
148148
points = [(seed + n_existing, x) for seed in range(n)]
149-
150-
loss_improvements = [0] * n # We set the loss_improvements of resamples to 0
149+
xl, xr = self.neighbors_combined[x]
150+
loss = (self.losses_combined[xl, x] + self.losses_combined[x, xr]) / 2
151+
loss_improvement = loss - loss * np.sqrt(n_existing) / np.sqrt(n_existing + n)
152+
loss_improvements = [loss_improvement / n] * n
151153
return points, loss_improvements
152154

153155
def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
154156
"""When asking for n new points, the learner returns n times a single
155157
new point, since in general n << min_samples and this point will need
156158
to be resampled many more times"""
157-
points, loss_improvements = self._ask_points_without_adding(1)
159+
points, (loss_improvement,) = self._ask_points_without_adding(1)
158160
points = [(seed, x) for seed, x in zip(range(n), n * points)]
159-
loss_improvements = loss_improvements + [0] * (n - 1)
161+
loss_improvements = [loss_improvement / n] * n
160162
return points, loss_improvements
161163

162164
def tell_pending(self, seed_x: Point) -> None:

adaptive/tests/test_learners.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,13 @@ def test_expected_loss_improvement_is_less_than_total_loss(
434434
"""The estimated loss improvement can never be greater than the total loss."""
435435
f = generate_random_parametrization(f)
436436
learner = learner_type(f, **learner_kwargs)
437-
N = random.randint(50, 100)
438-
xs, loss_improvements = learner.ask(N)
439-
440-
for x in xs:
441-
learner.tell(x, learner.function(x))
437+
for _ in range(2):
438+
# We do this twice to make sure that the AverageLearner1D
439+
# has two different points in `x`.
440+
N = random.randint(50, 100)
441+
xs, loss_improvements = learner.ask(N)
442+
for x in xs:
443+
learner.tell(x, learner.function(x))
442444

443445
M = random.randint(50, 100)
444446
_, loss_improvements = learner.ask(M)

0 commit comments

Comments
 (0)