Skip to content

Commit dac0b7d

Browse files
committed
fix point distribution for loss strategy
1 parent 53edb77 commit dac0b7d

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125125
self._points[index] = learner.ask(
126126
n=1, tell_pending=False)
127127
points, loss_improvements = self._points[index]
128-
npoints = npoints_per_learner[index] + learner.npoints
128+
npoints = (npoints_per_learner[index]
129+
+ learner.npoints
130+
+ len(learner.pending_points))
129131
priority = (loss_improvements[0], -npoints)
130132
improvements_per_learner.append(priority)
131133
points_per_learner.append((index, points[0]))
132134

133-
# Chose the optimal improvement.
135+
# Choose the optimal improvement.
134136
(index, point), (loss_improvement, _) = max(
135137
zip(points_per_learner, improvements_per_learner),
136138
key=itemgetter(1))
@@ -142,15 +144,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142144
return chosen_points, chosen_loss_improvements
143145

144146
def _ask_and_tell_based_on_loss(self, n):
145-
points = []
146-
loss_improvements = []
147+
chosen_points = []
148+
chosen_loss_improvements = []
149+
npoints_per_learner = defaultdict(int)
150+
147151
for _ in range(n):
148152
losses = self._losses(real=False)
149-
max_ind = np.argmax(losses)
150-
xs, ls = self.learners[max_ind].ask(1)
151-
points.append((max_ind, xs[0]))
152-
loss_improvements.append(ls[0])
153-
return points, loss_improvements
153+
npoints = [-(l.npoints
154+
+ npoints_per_learner[i]
155+
+ len(l.pending_points))
156+
for i, l in enumerate(self.learners)]
157+
priority = zip(losses, npoints)
158+
index, (_, _) = max(enumerate(priority), key=itemgetter(1))
159+
npoints_per_learner[index] += 1
160+
points, loss_improvements = self.learners[index].ask(1)
161+
chosen_points.append((index, points[0]))
162+
chosen_loss_improvements.append(loss_improvements[0])
163+
return chosen_points, chosen_loss_improvements
154164

155165
def _ask_and_tell_based_on_npoints(self, n):
156166
points = []

0 commit comments

Comments
 (0)