Skip to content

Commit b2d13a9

Browse files
committed
use a list for ((learner_index, point), loss_improvement) tuples
1 parent d60a88e commit b2d13a9

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,40 +113,34 @@ def strategy(self, strategy):
113113
' strategy="npoints" is implemented.')
114114

115115
def _ask_and_tell_based_on_loss_improvements(self, n):
116-
chosen_points = []
117-
chosen_loss_improvements = []
118-
npoints = [l.npoints + len(l.pending_points)
119-
for l in self.learners]
116+
selected = [] # tuples ((learner_index, point), loss_improvement)
117+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
120118
for _ in range(n):
121-
improvements_per_learner = []
122-
points_per_learner = []
119+
to_select = []
123120
for index, learner in enumerate(self.learners):
124121
# Take the points from the cache
125122
if index not in self._ask_cache:
126123
self._ask_cache[index] = learner.ask(
127124
n=1, tell_pending=False)
128125
points, loss_improvements = self._ask_cache[index]
129-
130-
priority = (loss_improvements[0], -npoints[index])
131-
improvements_per_learner.append(priority)
132-
points_per_learner.append((index, points[0]))
126+
to_select.append(
127+
((index, points[0]),
128+
(loss_improvements[0], -npoints[index]))
129+
)
133130

134131
# Choose the optimal improvement.
135132
(index, point), (loss_improvement, _) = max(
136-
zip(points_per_learner, improvements_per_learner),
137-
key=itemgetter(1))
133+
to_select, key=itemgetter(1))
138134
npoints[index] += 1
139-
chosen_points.append((index, point))
140-
chosen_loss_improvements.append(loss_improvement)
135+
selected.append(((index, point), loss_improvement))
141136
self.tell_pending((index, point))
142137

143-
return chosen_points, chosen_loss_improvements
138+
points, loss_improvements = map(list, zip(*selected))
139+
return points, loss_improvements
144140

145141
def _ask_and_tell_based_on_loss(self, n):
146-
chosen_points = []
147-
chosen_loss_improvements = []
148-
npoints = [l.npoints + len(l.pending_points)
149-
for l in self.learners]
142+
selected = [] # tuples ((learner_index, point), loss_improvement)
143+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
150144
for _ in range(n):
151145
losses = self._losses(real=False)
152146
priority = zip(losses, (-n for n in npoints))
@@ -158,15 +152,14 @@ def _ask_and_tell_based_on_loss(self, n):
158152
self._ask_cache[index] = self.learners[index].ask(n=1)
159153
points, loss_improvements = self._ask_cache[index]
160154

161-
chosen_points.append((index, points[0]))
162-
chosen_loss_improvements.append(loss_improvements[0])
163-
return chosen_points, chosen_loss_improvements
155+
selected.append(((index, points[0]), loss_improvements[0]))
156+
157+
points, loss_improvements = map(list, zip(*selected))
158+
return points, loss_improvements
164159

165160
def _ask_and_tell_based_on_npoints(self, n):
166-
chosen_points = []
167-
chosen_loss_improvements = []
168-
npoints = [l.npoints + len(l.pending_points)
169-
for l in self.learners]
161+
selected = [] # tuples ((learner_index, point), loss_improvement)
162+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
170163
n_left = n
171164
while n_left > 0:
172165
index = np.argmin(npoints)
@@ -176,9 +169,10 @@ def _ask_and_tell_based_on_npoints(self, n):
176169
points, loss_improvements = self._ask_cache[index]
177170
npoints[index] += 1
178171
n_left -= 1
179-
chosen_points.append((index, points[0]))
180-
chosen_loss_improvements.append(loss_improvements[0])
181-
return chosen_points, chosen_loss_improvements
172+
selected.append(((index, points[0]), loss_improvements[0]))
173+
174+
points, loss_improvements = map(list, zip(*selected))
175+
return points, loss_improvements
182176

183177
def ask(self, n, tell_pending=True):
184178
"""Chose points for learners."""

0 commit comments

Comments
 (0)