@@ -115,8 +115,8 @@ def strategy(self, strategy):
115
115
def _ask_and_tell_based_on_loss_improvements (self , n ):
116
116
chosen_points = []
117
117
chosen_loss_improvements = []
118
- npoints_per_learner = defaultdict ( int )
119
-
118
+ npoints = [ l . npoints + len ( l . pending_points )
119
+ for l in self . learners ]
120
120
for _ in range (n ):
121
121
improvements_per_learner = []
122
122
points_per_learner = []
@@ -126,18 +126,16 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
126
126
self ._points [index ] = learner .ask (
127
127
n = 1 , tell_pending = False )
128
128
points , loss_improvements = self ._points [index ]
129
- npoints = (npoints_per_learner [index ]
130
- + learner .npoints
131
- + len (learner .pending_points ))
132
- priority = (loss_improvements [0 ], - npoints )
129
+
130
+ priority = (loss_improvements [0 ], - npoints [index ])
133
131
improvements_per_learner .append (priority )
134
132
points_per_learner .append ((index , points [0 ]))
135
133
136
134
# Choose the optimal improvement.
137
135
(index , point ), (loss_improvement , _ ) = max (
138
136
zip (points_per_learner , improvements_per_learner ),
139
137
key = itemgetter (1 ))
140
- npoints_per_learner [index ] += 1
138
+ npoints [index ] += 1
141
139
chosen_points .append ((index , point ))
142
140
chosen_loss_improvements .append (loss_improvement )
143
141
self .tell_pending ((index , point ))
@@ -147,17 +145,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
147
145
def _ask_and_tell_based_on_loss (self , n ):
148
146
chosen_points = []
149
147
chosen_loss_improvements = []
150
- npoints_per_learner = defaultdict ( int )
151
-
148
+ npoints = [ l . npoints + len ( l . pending_points )
149
+ for l in self . learners ]
152
150
for _ in range (n ):
153
151
losses = self ._losses (real = False )
154
- npoints = [- (l .npoints
155
- + npoints_per_learner [i ]
156
- + len (l .pending_points ))
157
- for i , l in enumerate (self .learners )]
158
- priority = zip (losses , npoints )
152
+ priority = zip (losses , (- n for n in npoints ))
159
153
index = max (enumerate (priority ), key = itemgetter (1 ))[0 ]
160
- npoints_per_learner [index ] += 1
154
+ npoints [index ] += 1
161
155
162
156
# Take the points from the cache
163
157
if index not in self ._points :
0 commit comments