14
14
15
15
import logging
16
16
17
- from copy import copy
17
+ from copy import copy , deepcopy
18
+ from numba import jit
18
19
19
20
import aesara
20
21
import numpy as np
@@ -173,11 +174,8 @@ def astep(self, _):
173
174
# Normalize weights
174
175
w_t , normalized_weights = self .normalize (particles [2 :])
175
176
176
- # Resample all but first two particles
177
- new_indices = np .random .choice (
178
- self .indices , size = self .len_indices , p = normalized_weights
179
- )
180
- particles [2 :] = particles [new_indices ]
177
+ # Resample
178
+ particles = self .resample (particles , normalized_weights )
181
179
182
180
# Set the new weight
183
181
for p in particles [2 :]:
@@ -196,12 +194,14 @@ def astep(self, _):
196
194
self .sum_trees = self .sum_trees_noi + new_tree ._predict ()
197
195
self .all_trees [tree_id ] = new_tree .trim ()
198
196
197
+ used_variates = new_tree .get_split_variables ()
198
+
199
199
if self .tune :
200
200
self .ssv = SampleSplittingVariable (self .alpha_vec )
201
- for index in new_particle . used_variates :
201
+ for index in used_variates :
202
202
self .alpha_vec [index ] += 1
203
203
else :
204
- for index in new_particle . used_variates :
204
+ for index in used_variates :
205
205
variable_inclusion [index ] += 1
206
206
207
207
stats = {"variable_inclusion" : variable_inclusion , "bart_trees" : copy (self .all_trees )}
@@ -225,6 +225,21 @@ def normalize(self, particles):
225
225
226
226
return w_t , normalized_weights
227
227
228
def resample(self, particles, normalized_weights):
    """Use systematic resample for all but first two particles.

    Parameters
    ----------
    particles :
        Sequence of particles. The first two entries are left untouched;
        every entry from position 2 onwards is replaced.
    normalized_weights :
        Normalized weights of ``particles[2:]``, passed to ``systematic``.

    Returns
    -------
    The same ``particles`` container, with ``particles[2:]`` overwritten by
    the resampled particles.
    """
    # `systematic` already shifts its result by 2, so the indices address
    # `particles` directly and never select the first two entries.
    new_indices = systematic(normalized_weights)

    seen = set()
    new_particles = []
    for idx in new_indices:
        if idx in seen:
            # A particle picked more than once is deep-copied so that the
            # resulting slots do not alias the same object.
            new_particles.append(deepcopy(particles[idx]))
        else:
            seen.add(idx)
            new_particles.append(particles[idx])

    particles[2:] = new_particles

    return particles
242
+
228
243
def init_particles (self , tree_id : int ) -> np .ndarray :
229
244
"""Initialize particles."""
230
245
p0 = self .all_particles [tree_id ]
@@ -286,9 +301,9 @@ def __init__(self, tree):
286
301
self .expansion_nodes = [0 ]
287
302
self .log_weight = 0
288
303
self .old_likelihood_logp = 0
289
- self .used_variates = []
290
304
self .kf = 0.75
291
305
306
+
292
307
def sample_tree (
293
308
self ,
294
309
ssv ,
@@ -326,7 +341,6 @@ def sample_tree(
326
341
if index_selected_predictor is not None :
327
342
new_indexes = self .tree .idx_leaf_nodes [- 2 :]
328
343
self .expansion_nodes .extend (new_indexes )
329
- self .used_variates .append (index_selected_predictor )
330
344
tree_grew = True
331
345
332
346
return tree_grew
@@ -524,7 +538,7 @@ def mean(a):
524
538
525
539
return mean
526
540
527
-
541
+ @ jit ()
528
542
def discrete_uniform_sampler (upper_value ):
529
543
"""Draw from the uniform distribution with bounds [0, upper_value).
530
544
@@ -541,14 +555,15 @@ def __init__(self, scale, shape):
541
555
self .scale = scale
542
556
self .shape = shape
543
557
self .update ()
544
-
558
+
545
559
def random(self):
    """Return the next column of the cached draws.

    When the read position ``self.idx`` has reached ``self.size``,
    ``self.update()`` is called first. The read position is advanced by one
    after the column has been taken.
    """
    cache_exhausted = self.idx == self.size
    if cache_exhausted:
        self.update()
    position = self.idx
    draw = self.cache[:, position]
    self.idx = position + 1
    return draw
551
565
566
+
552
567
def update (self ):
553
568
self .idx = 0
554
569
self .cache = np .random .normal (loc = 0.0 , scale = self .scale , size = (self .shape , self .size ))
@@ -571,12 +586,47 @@ def random(self):
571
586
self .idx += 1
572
587
return pop
573
588
589
+
574
590
def update(self):
    """Reset the read position and refill the cache with uniform draws.

    Sets ``self.idx`` to 0 and stores in ``self.cache`` an array of shape
    ``(self.shape, self.size)`` drawn with ``np.random.uniform`` between
    ``self.lower_bound`` and ``self.upper_bound``.
    """
    self.idx = 0
    low, high = self.lower_bound, self.upper_bound
    cache_dims = (self.shape, self.size)
    self.cache = np.random.uniform(low, high, size=cache_dims)
579
595
596
@jit()
def systematic(W):
    """Systematic resampling.

    Builds ``len(W)`` evenly spaced points in [0, 1) that share a single
    random offset, maps them to indices with ``inverse_cdf`` and adds 2 to
    every index before returning.

    Parameters
    ----------
    W: (N,) ndarray
        a vector of N normalized weights

    Returns
    -------
    (N,) ndarray
        the indices returned by ``inverse_cdf``, each shifted by 2
    """
    n_points = len(W)
    offset = np.random.rand(1)
    grid = (offset + np.arange(n_points)) / n_points
    return inverse_cdf(grid, W) + 2
603
+
604
+
605
+ @jit (nopython = True )
606
def inverse_cdf(su, W):
    """Inverse CDF algorithm for a finite distribution.

    Parameters
    ----------
    su: (M,) ndarray
        M sorted uniform variates (i.e. M ordered points in [0,1]).
    W: (N,) ndarray
        a vector of N normalized weights (>=0 and sum to one)

    Returns
    -------
    A: (M,) ndarray
        a vector of M indices in range 0, ..., N-1
    """
    j = 0
    s = W[0]
    M = su.shape[0]
    last = len(W) - 1
    A = np.empty(M, dtype=np.int64)
    for n in range(M):
        # Clamp to the last index: floating-point rounding can leave the
        # running sum of W slightly below a variate close to 1, and stepping
        # past the end of W would be an out-of-bounds read.
        while su[n] > s and j < last:
            j += 1
            s += W[j]
        A[n] = j
    return A
629
+
580
630
581
631
def logp (point , out_vars , vars , shared ):
582
632
"""Compile Aesara function of the model and the input and output variables.
0 commit comments