@@ -125,12 +125,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125
125
self ._points [index ] = learner .ask (
126
126
n = 1 , tell_pending = False )
127
127
points , loss_improvements = self ._points [index ]
128
- npoints = npoints_per_learner [index ] + learner .npoints
128
+ npoints = (npoints_per_learner [index ]
129
+ + learner .npoints
130
+ + len (learner .pending_points ))
129
131
priority = (loss_improvements [0 ], - npoints )
130
132
improvements_per_learner .append (priority )
131
133
points_per_learner .append ((index , points [0 ]))
132
134
133
- # Chose the optimal improvement.
135
+ # Choose the optimal improvement.
134
136
(index , point ), (loss_improvement , _ ) = max (
135
137
zip (points_per_learner , improvements_per_learner ),
136
138
key = itemgetter (1 ))
@@ -142,15 +144,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142
144
return chosen_points , chosen_loss_improvements
143
145
144
146
def _ask_and_tell_based_on_loss (self , n ):
145
- points = []
146
- loss_improvements = []
147
+ chosen_points = []
148
+ chosen_loss_improvements = []
149
+ npoints_per_learner = defaultdict (int )
150
+
147
151
for _ in range (n ):
148
152
losses = self ._losses (real = False )
149
- max_ind = np .argmax (losses )
150
- xs , ls = self .learners [max_ind ].ask (1 )
151
- points .append ((max_ind , xs [0 ]))
152
- loss_improvements .append (ls [0 ])
153
- return points , loss_improvements
153
+ npoints = [- (l .npoints
154
+ + npoints_per_learner [i ]
155
+ + len (l .pending_points ))
156
+ for i , l in enumerate (self .learners )]
157
+ priority = zip (losses , npoints )
158
+ index , (_ , _ ) = max (enumerate (priority ), key = itemgetter (1 ))
159
+ npoints_per_learner [index ] += 1
160
+ points , loss_improvements = self .learners [index ].ask (1 )
161
+ chosen_points .append ((index , points [0 ]))
162
+ chosen_loss_improvements .append (loss_improvements [0 ])
163
+ return chosen_points , chosen_loss_improvements
154
164
155
165
def _ask_and_tell_based_on_npoints (self , n ):
156
166
points = []
0 commit comments