@@ -146,17 +146,19 @@ def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]
146
146
need to be resampled many more times"""
147
147
n_existing = self ._number_samples .get (x , 0 )
148
148
points = [(seed + n_existing , x ) for seed in range (n )]
149
-
150
- loss_improvements = [0 ] * n # We set the loss_improvements of resamples to 0
149
+ xl , xr = self .neighbors_combined [x ]
150
+ loss = (self .losses_combined [xl , x ] + self .losses_combined [x , xr ]) / 2
151
+ loss_improvement = loss - loss * np .sqrt (n_existing ) / np .sqrt (n_existing + n )
152
+ loss_improvements = [loss_improvement / n ] * n
151
153
return points , loss_improvements
152
154
153
155
def _ask_for_new_point (self , n : int ) -> Tuple [Points , List [float ]]:
154
156
"""When asking for n new points, the learner returns n times a single
155
157
new point, since in general n << min_samples and this point will need
156
158
to be resampled many more times"""
157
- points , loss_improvements = self ._ask_points_without_adding (1 )
159
+ points , ( loss_improvement ,) = self ._ask_points_without_adding (1 )
158
160
points = [(seed , x ) for seed , x in zip (range (n ), n * points )]
159
- loss_improvements = loss_improvements + [ 0 ] * ( n - 1 )
161
+ loss_improvements = [ loss_improvement / n ] * n
160
162
return points , loss_improvements
161
163
162
164
def tell_pending (self , seed_x : Point ) -> None :
0 commit comments