Skip to content

Commit d9ceaf7

Browse files
committed
make the ask functions more similar
1 parent 1401a9a commit d9ceaf7

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def strategy(self, strategy):
115115
def _ask_and_tell_based_on_loss_improvements(self, n):
116116
chosen_points = []
117117
chosen_loss_improvements = []
118-
npoints_per_learner = defaultdict(int)
119-
118+
npoints = [l.npoints + len(l.pending_points)
119+
for l in self.learners]
120120
for _ in range(n):
121121
improvements_per_learner = []
122122
points_per_learner = []
@@ -126,18 +126,16 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
126126
self._points[index] = learner.ask(
127127
n=1, tell_pending=False)
128128
points, loss_improvements = self._points[index]
129-
npoints = (npoints_per_learner[index]
130-
+ learner.npoints
131-
+ len(learner.pending_points))
132-
priority = (loss_improvements[0], -npoints)
129+
130+
priority = (loss_improvements[0], -npoints[index])
133131
improvements_per_learner.append(priority)
134132
points_per_learner.append((index, points[0]))
135133

136134
# Choose the optimal improvement.
137135
(index, point), (loss_improvement, _) = max(
138136
zip(points_per_learner, improvements_per_learner),
139137
key=itemgetter(1))
140-
npoints_per_learner[index] += 1
138+
npoints[index] += 1
141139
chosen_points.append((index, point))
142140
chosen_loss_improvements.append(loss_improvement)
143141
self.tell_pending((index, point))
@@ -147,17 +145,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
147145
def _ask_and_tell_based_on_loss(self, n):
148146
chosen_points = []
149147
chosen_loss_improvements = []
150-
npoints_per_learner = defaultdict(int)
151-
148+
npoints = [l.npoints + len(l.pending_points)
149+
for l in self.learners]
152150
for _ in range(n):
153151
losses = self._losses(real=False)
154-
npoints = [-(l.npoints
155-
+ npoints_per_learner[i]
156-
+ len(l.pending_points))
157-
for i, l in enumerate(self.learners)]
158-
priority = zip(losses, npoints)
152+
priority = zip(losses, (-n for n in npoints))
159153
index = max(enumerate(priority), key=itemgetter(1))[0]
160-
npoints_per_learner[index] += 1
154+
npoints[index] += 1
161155

162156
# Take the points from the cache
163157
if index not in self._points:

0 commit comments

Comments
 (0)