@@ -121,6 +121,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
121
121
improvements_per_learner = []
122
122
points_per_learner = []
123
123
for index , learner in enumerate (self .learners ):
124
+ # Take the points from the cache
124
125
if index not in self ._points :
125
126
self ._points [index ] = learner .ask (
126
127
n = 1 , tell_pending = False )
@@ -155,27 +156,35 @@ def _ask_and_tell_based_on_loss(self, n):
155
156
+ len (l .pending_points ))
156
157
for i , l in enumerate (self .learners )]
157
158
priority = zip (losses , npoints )
158
- index , ( _ , _ ) = max (enumerate (priority ), key = itemgetter (1 ))
159
+ index = max (enumerate (priority ), key = itemgetter (1 ))[ 0 ]
159
160
npoints_per_learner [index ] += 1
160
- points , loss_improvements = self .learners [index ].ask (1 )
161
+
162
+ # Take the points from the cache
163
+ if index not in self ._points :
164
+ self ._points [index ] = self .learners [index ].ask (n = 1 )
165
+ points , loss_improvements = self ._points [index ]
166
+
161
167
chosen_points .append ((index , points [0 ]))
162
168
chosen_loss_improvements .append (loss_improvements [0 ])
163
169
return chosen_points , chosen_loss_improvements
164
170
165
171
def _ask_and_tell_based_on_npoints (self , n ):
166
- points = []
167
- loss_improvements = []
172
+ chosen_points = []
173
+ chosen_loss_improvements = []
168
174
npoints = [l .npoints + len (l .pending_points )
169
175
for l in self .learners ]
170
176
n_left = n
171
177
while n_left > 0 :
172
- i = np .argmin (npoints )
173
- xs , ls = self .learners [i ].ask (1 )
174
- npoints [i ] += 1
178
+ index = np .argmin (npoints )
179
+ # Take the points from the cache
180
+ if index not in self ._points :
181
+ self ._points [index ] = self .learners [index ].ask (n = 1 )
182
+ points , loss_improvements = self ._points [index ]
183
+ npoints [index ] += 1
175
184
n_left -= 1
176
- points .append ((i , xs [0 ]))
177
- loss_improvements .append (ls [0 ])
178
- return points , loss_improvements
185
+ chosen_points .append ((index , points [0 ]))
186
+ chosen_loss_improvements .append (loss_improvements [0 ])
187
+ return chosen_points , chosen_loss_improvements
179
188
180
189
def ask (self , n , tell_pending = True ):
181
190
"""Chose points for learners."""
0 commit comments