@@ -123,6 +123,7 @@ def __init__(
123
123
else :
124
124
self .batch = (batch , batch )
125
125
126
+ self .num_particles = num_particles
126
127
self .log_num_particles = np .log (num_particles )
127
128
self .indices = list (range (2 , num_particles ))
128
129
self .len_indices = len (self .indices )
@@ -185,14 +186,10 @@ def astep(self, _):
185
186
186
187
_ , normalized_weights = self .normalize (particles )
187
188
# Get the new tree and update
188
- new_particle = np .random .choice (particles , p = normalized_weights )
189
- new_tree = new_particle .tree
190
-
191
- new_particle .log_weight = new_particle .old_likelihood_logp - self .log_num_particles
189
+ new_particle , new_tree = self .get_particle_tree (particles , normalized_weights )
192
190
self .all_particles [tree_id ] = new_particle
193
191
self .sum_trees = self .sum_trees_noi + new_tree ._predict ()
194
192
self .all_trees [tree_id ] = new_tree .trim ()
195
-
196
193
used_variates = new_tree .get_split_variables ()
197
194
198
195
if self .tune :
@@ -230,7 +227,7 @@ def resample(self, particles, normalized_weights):
230
227
231
228
Ensure particles are copied only if needed.
232
229
"""
233
- new_indices = systematic (normalized_weights )
230
+ new_indices = self . systematic (normalized_weights )
234
231
seen = []
235
232
new_particles = []
236
233
for idx in new_indices :
@@ -244,6 +241,29 @@ def resample(self, particles, normalized_weights):
244
241
245
242
return particles
246
243
244
+ def get_particle_tree (self , particles , normalized_weights ):
245
+ """
246
+ Sample a new particle, new tree and update log_weight
247
+ """
248
+ new_index = self .systematic (normalized_weights )[
249
+ discrete_uniform_sampler (self .num_particles )
250
+ ]
251
+ new_particle = particles [new_index - 2 ]
252
+ new_particle .log_weight = new_particle .old_likelihood_logp - self .log_num_particles
253
+ return new_particle , new_particle .tree
254
+
255
+ def systematic (self , normalized_weights ):
256
+ """
257
+ Systematic resampling.
258
+
259
+ Return indices in the range 2, ..., len(normalized_weights)+2
260
+
261
+ Note: adapted from https://github.com/nchopin/particles
262
+ """
263
+ lnw = len (normalized_weights )
264
+ single_uniform = (self .uniform .random () + np .arange (lnw )) / lnw
265
+ return inverse_cdf (single_uniform , normalized_weights ) + 2
266
+
247
267
def init_particles (self , tree_id : int ) -> np .ndarray :
248
268
"""Initialize particles."""
249
269
p0 = self .all_particles [tree_id ]
@@ -584,19 +604,6 @@ def update(self):
584
604
)
585
605
586
606
587
- def systematic (normalized_weights ):
588
- """
589
- Systematic resampling.
590
-
591
- Return indices in the range 2, ..., len(normalized_weights)+2
592
-
593
- Note: adapted from https://github.com/nchopin/particles
594
- """
595
- lnw = len (normalized_weights )
596
- single_uniform = (np .random .rand (1 ) + np .arange (lnw )) / lnw
597
- return inverse_cdf (single_uniform , normalized_weights ) + 2
598
-
599
-
600
607
@njit
601
608
def inverse_cdf (single_uniform , normalized_weights ):
602
609
"""
0 commit comments