Skip to content

Commit a00bf61

Browse files
committed
fix the point distribution issue by max(loss, -npoints)
1 parent a7a1822 commit a00bf61

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,33 @@ def strategy(self, strategy):
113113
' strategy="npoints" is implemented.')
114114

115115
def _ask_and_tell_based_on_loss_improvements(self, n):
116-
points = []
117-
loss_improvements = []
116+
chosen_points = []
117+
chosen_loss_improvements = []
118+
npoints_per_learner = defaultdict(int)
119+
118120
for _ in range(n):
119121
improvements_per_learner = []
120-
pairs = []
122+
points_per_learner = []
121123
for index, learner in enumerate(self.learners):
122124
if index not in self._points:
123125
self._points[index] = learner.ask(
124126
n=1, tell_pending=False)
125-
point, loss_improvement = self._points[index]
126-
improvements_per_learner.append(loss_improvement[0])
127-
pairs.append((index, point[0]))
128-
x, l = max(zip(pairs, improvements_per_learner),
129-
key=itemgetter(1))
130-
points.append(x)
131-
loss_improvements.append(l)
132-
self.tell_pending(x)
133-
134-
return points, loss_improvements
127+
points, loss_improvements = self._points[index]
128+
npoints = npoints_per_learner[index] + learner.npoints
129+
priority = (loss_improvements[0], -npoints)
130+
improvements_per_learner.append(priority)
131+
points_per_learner.append((index, points[0]))
132+
133+
# Chose the optimal improvement.
134+
(index, point), (loss_improvement, _) = max(
135+
zip(points_per_learner, improvements_per_learner),
136+
key=itemgetter(1))
137+
npoints_per_learner[index] += 1
138+
chosen_points.append((index, point))
139+
chosen_loss_improvements.append(loss_improvement)
140+
self.tell_pending((index, point))
141+
142+
return chosen_points, chosen_loss_improvements
135143

136144
def _ask_and_tell_based_on_loss(self, n):
137145
points = []
@@ -161,19 +169,11 @@ def _ask_and_tell_based_on_npoints(self, n):
161169

162170
def ask(self, n, tell_pending=True):
163171
"""Chose points for learners."""
164-
if any(l.npoints for l in self.learners):
165-
ask_and_tell = self._ask_and_tell
166-
else:
167-
# If there are no data points yet,
168-
# distribute the points over all learners.
169-
# See https://github.com/python-adaptive/adaptive/issues/159
170-
ask_and_tell = self._ask_and_tell_based_on_npoints
171-
172172
if not tell_pending:
173173
with restore(*self.learners):
174-
return ask_and_tell(n)
174+
return self._ask_and_tell(n)
175175
else:
176-
return ask_and_tell(n)
176+
return self._ask_and_tell(n)
177177

178178
def tell(self, x, y):
179179
index, x = x

0 commit comments

Comments
 (0)