Skip to content

Commit 1401a9a

Browse files
committed
use the cache for all strategies
1 parent dac0b7d commit 1401a9a

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
121121
improvements_per_learner = []
122122
points_per_learner = []
123123
for index, learner in enumerate(self.learners):
124+
# Take the points from the cache
124125
if index not in self._points:
125126
self._points[index] = learner.ask(
126127
n=1, tell_pending=False)
@@ -155,27 +156,35 @@ def _ask_and_tell_based_on_loss(self, n):
155156
+ len(l.pending_points))
156157
for i, l in enumerate(self.learners)]
157158
priority = zip(losses, npoints)
158-
index, (_, _) = max(enumerate(priority), key=itemgetter(1))
159+
index = max(enumerate(priority), key=itemgetter(1))[0]
159160
npoints_per_learner[index] += 1
160-
points, loss_improvements = self.learners[index].ask(1)
161+
162+
# Take the points from the cache
163+
if index not in self._points:
164+
self._points[index] = self.learners[index].ask(n=1)
165+
points, loss_improvements = self._points[index]
166+
161167
chosen_points.append((index, points[0]))
162168
chosen_loss_improvements.append(loss_improvements[0])
163169
return chosen_points, chosen_loss_improvements
164170

165171
def _ask_and_tell_based_on_npoints(self, n):
166-
points = []
167-
loss_improvements = []
172+
chosen_points = []
173+
chosen_loss_improvements = []
168174
npoints = [l.npoints + len(l.pending_points)
169175
for l in self.learners]
170176
n_left = n
171177
while n_left > 0:
172-
i = np.argmin(npoints)
173-
xs, ls = self.learners[i].ask(1)
174-
npoints[i] += 1
178+
index = np.argmin(npoints)
179+
# Take the points from the cache
180+
if index not in self._points:
181+
self._points[index] = self.learners[index].ask(n=1)
182+
points, loss_improvements = self._points[index]
183+
npoints[index] += 1
175184
n_left -= 1
176-
points.append((i, xs[0]))
177-
loss_improvements.append(ls[0])
178-
return points, loss_improvements
185+
chosen_points.append((index, points[0]))
186+
chosen_loss_improvements.append(loss_improvements[0])
187+
return chosen_points, chosen_loss_improvements
179188

180189
def ask(self, n, tell_pending=True):
181190
"""Chose points for learners."""

0 commit comments

Comments
 (0)