Skip to content

Commit 3b85da5

Browse files
committed
fix bug where loss_improvement becomes nan
1 parent cf8af90 commit 3b85da5

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from collections import defaultdict
23
from copy import deepcopy
34
from math import hypot
@@ -147,8 +148,15 @@ def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]
147148
n_existing = self._number_samples.get(x, 0)
148149
points = [(seed + n_existing, x) for seed in range(n)]
149150
xl, xr = self.neighbors_combined[x]
150-
loss = (self.losses_combined[xl, x] + self.losses_combined[x, xr]) / 2
151-
loss_improvement = loss - loss * np.sqrt(n_existing) / np.sqrt(n_existing + n)
151+
loss_left = self.losses_combined.get((xl, x), float("inf"))
152+
loss_right = self.losses_combined.get((x, xr), float("inf"))
153+
loss = (loss_left + loss_right) / 2
154+
if math.isinf(loss):
155+
loss_improvement = float("inf")
156+
else:
157+
loss_improvement = loss - loss * np.sqrt(n_existing) / np.sqrt(
158+
n_existing + n
159+
)
152160
loss_improvements = [loss_improvement / n] * n
153161
return points, loss_improvements
154162

0 commit comments

Comments
 (0)