@@ -113,25 +113,33 @@ def strategy(self, strategy):
113
113
' strategy="npoints" is implemented.' )
114
114
115
115
def _ask_and_tell_based_on_loss_improvements (self , n ):
116
- points = []
117
- loss_improvements = []
116
+ chosen_points = []
117
+ chosen_loss_improvements = []
118
+ npoints_per_learner = defaultdict (int )
119
+
118
120
for _ in range (n ):
119
121
improvements_per_learner = []
120
- pairs = []
122
+ points_per_learner = []
121
123
for index , learner in enumerate (self .learners ):
122
124
if index not in self ._points :
123
125
self ._points [index ] = learner .ask (
124
126
n = 1 , tell_pending = False )
125
- point , loss_improvement = self ._points [index ]
126
- improvements_per_learner .append (loss_improvement [0 ])
127
- pairs .append ((index , point [0 ]))
128
- x , l = max (zip (pairs , improvements_per_learner ),
129
- key = itemgetter (1 ))
130
- points .append (x )
131
- loss_improvements .append (l )
132
- self .tell_pending (x )
133
-
134
- return points , loss_improvements
127
+ points , loss_improvements = self ._points [index ]
128
+ npoints = npoints_per_learner [index ] + learner .npoints
129
+ priority = (loss_improvements [0 ], - npoints )
130
+ improvements_per_learner .append (priority )
131
+ points_per_learner .append ((index , points [0 ]))
132
+
133
+ # Chose the optimal improvement.
134
+ (index , point ), (loss_improvement , _ ) = max (
135
+ zip (points_per_learner , improvements_per_learner ),
136
+ key = itemgetter (1 ))
137
+ npoints_per_learner [index ] += 1
138
+ chosen_points .append ((index , point ))
139
+ chosen_loss_improvements .append (loss_improvement )
140
+ self .tell_pending ((index , point ))
141
+
142
+ return chosen_points , chosen_loss_improvements
135
143
136
144
def _ask_and_tell_based_on_loss (self , n ):
137
145
points = []
@@ -161,19 +169,11 @@ def _ask_and_tell_based_on_npoints(self, n):
161
169
162
170
def ask (self , n , tell_pending = True ):
163
171
"""Chose points for learners."""
164
- if any (l .npoints for l in self .learners ):
165
- ask_and_tell = self ._ask_and_tell
166
- else :
167
- # If there are no data points yet,
168
- # distribute the points over all learners.
169
- # See https://github.com/python-adaptive/adaptive/issues/159
170
- ask_and_tell = self ._ask_and_tell_based_on_npoints
171
-
172
172
if not tell_pending :
173
173
with restore (* self .learners ):
174
- return ask_and_tell (n )
174
+ return self . _ask_and_tell (n )
175
175
else :
176
- return ask_and_tell (n )
176
+ return self . _ask_and_tell (n )
177
177
178
178
def tell (self , x , y ):
179
179
index , x = x
0 commit comments