import logging

-from copy import copy
+from copy import deepcopy
+from numba import njit

import aesara
import numpy as np
@@ -56,7 +57,7 @@ class PGBART(ArrayStepShared):
    def __init__(
        self,
        vars=None,
-        num_particles=40,
+        num_particles=20,
        batch="auto",
        model=None,
    ):
@@ -104,8 +105,6 @@ def __init__(
            idx_data_points=np.arange(self.num_observations, dtype="int32"),
            shape=self.shape,
        )
-        self.mean = fast_mean()
-
        self.normal = NormalSampler(mu_std, self.shape)
        self.uniform = UniformSampler(0.33, 0.75, self.shape)
        self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -158,7 +157,6 @@ def astep(self, _):
                    self.X,
                    self.missing_data,
                    self.sum_trees,
-                    self.mean,
                    self.m,
                    self.normal,
                    self.shape,
@@ -173,11 +171,8 @@ def astep(self, _):
            # Normalize weights
            w_t, normalized_weights = self.normalize(particles[2:])

-            # Resample all but first two particles
-            new_indices = np.random.choice(
-                self.indices, size=self.len_indices, p=normalized_weights
-            )
-            particles[2:] = particles[new_indices]
+            # Resample
+            particles = self.resample(particles, normalized_weights)

            # Set the new weight
            for p in particles[2:]:
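This hunk swaps independent multinomial draws (np.random.choice) for systematic resampling, which places one draw in each equal-width stratum of [0, 1] and so yields offspring counts with lower variance. A minimal NumPy-only sketch of the difference, with made-up weights (this is not the sampler's own code):

import numpy as np

rng = np.random.default_rng(0)
weights = np.full(10, 0.1)  # hypothetical equal normalized weights

# Multinomial (old scheme): 10 independent draws; counts fluctuate.
counts_mult = np.bincount(rng.choice(10, size=10, p=weights), minlength=10)

# Systematic (new scheme): one uniform shifted across a regular grid;
# with equal weights every particle gets exactly one offspring.
u = (rng.random() + np.arange(10)) / 10
counts_sys = np.bincount(np.searchsorted(np.cumsum(weights), u), minlength=10)

print(counts_mult.max(), counts_sys.max())  # systematic max is exactly 1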
@@ -196,15 +191,17 @@ def astep(self, _):
            self.sum_trees = self.sum_trees_noi + new_tree._predict()
            self.all_trees[tree_id] = new_tree.trim()

+            used_variates = new_tree.get_split_variables()
+
            if self.tune:
                self.ssv = SampleSplittingVariable(self.alpha_vec)
-                for index in new_particle.used_variates:
+                for index in used_variates:
                    self.alpha_vec[index] += 1
            else:
-                for index in new_particle.used_variates:
+                for index in used_variates:
                    variable_inclusion[index] += 1

-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
+        stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
        return self.sum_trees, [stats]

    def normalize(self, particles):
@@ -225,18 +222,36 @@ def normalize(self, particles):

        return w_t, normalized_weights

+    def resample(self, particles, normalized_weights):
+        """
+        Use systematic resampling for all but the first two particles.
+
+        Ensure particles are copied only if needed.
+        """
+        new_indices = systematic(normalized_weights)
+        seen = []
+        new_particles = []
+        for idx in new_indices:
+            if idx in seen:
+                new_particles.append(deepcopy(particles[idx]))
+            else:
+                new_particles.append(particles[idx])
+                seen.append(idx)
+
+        particles[2:] = new_particles
+
+        return particles
+
    def init_particles(self, tree_id: int) -> np.ndarray:
        """Initialize particles."""
        p0 = self.all_particles[tree_id]
-        p1 = copy(p0)
+        p1 = deepcopy(p0)
        p1.sample_leafs(
            self.sum_trees,
-            self.mean,
            self.m,
            self.normal,
            self.shape,
        )
-
        # The old tree and the one with new leafs do not grow so we update the weights only once
        self.update_weight(p0, old=True)
        self.update_weight(p1, old=True)
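The new resample helper avoids deep-copying every surviving particle: an index reuses the existing object the first time it is drawn, and is deep-copied only on repeated draws, since duplicates must be mutable independently in later grow steps. A standalone sketch of this copy-only-if-needed pattern (toy list elements stand in for ParticleTree objects):

from copy import deepcopy

def lazy_resample(items, new_indices):
    # First occurrence of an index reuses the object; repeats get a deep
    # copy so each resampled particle can be mutated independently later.
    seen = []
    out = []
    for idx in new_indices:
        if idx in seen:
            out.append(deepcopy(items[idx]))
        else:
            out.append(items[idx])
            seen.append(idx)
    return out

resampled = lazy_resample([["a"], ["b"], ["c"]], [1, 1, 2])
resampled[1].append("x")  # mutates only the copy
print(resampled)  # [['b'], ['b', 'x'], ['c']]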
@@ -286,7 +301,6 @@ def __init__(self, tree):
        self.expansion_nodes = [0]
        self.log_weight = 0
        self.old_likelihood_logp = 0
-        self.used_variates = []
        self.kf = 0.75

    def sample_tree(
@@ -297,7 +311,6 @@ def sample_tree(
        X,
        missing_data,
        sum_trees,
-        mean,
        m,
        normal,
        shape,
@@ -317,7 +330,6 @@ def sample_tree(
                X,
                missing_data,
                sum_trees,
-                mean,
                m,
                normal,
                self.kf,
@@ -326,20 +338,18 @@ def sample_tree(
            if index_selected_predictor is not None:
                new_indexes = self.tree.idx_leaf_nodes[-2:]
                self.expansion_nodes.extend(new_indexes)
-                self.used_variates.append(index_selected_predictor)
                tree_grew = True

        return tree_grew

-    def sample_leafs(self, sum_trees, mean, m, normal, shape):
+    def sample_leafs(self, sum_trees, m, normal, shape):

        for idx in self.tree.idx_leaf_nodes:
            if idx > 0:
                leaf = self.tree[idx]
                idx_data_points = leaf.idx_data_points
                node_value = draw_leaf_value(
                    sum_trees[:, idx_data_points],
-                    mean,
                    m,
                    normal,
                    self.kf,
@@ -400,7 +410,6 @@ def grow_tree(
    X,
    missing_data,
    sum_trees,
-    mean,
    m,
    normal,
    kf,
@@ -429,7 +438,6 @@ def grow_tree(
        idx_data_point = new_idx_data_points[idx]
        node_value = draw_leaf_value(
            sum_trees[:, idx_data_point],
-            mean,
            m,
            normal,
            kf,
@@ -482,7 +490,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
    return split_value


-def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
+def draw_leaf_value(Y_mu_pred, m, normal, kf, shape):
    """Draw Gaussian distributed leaf values."""
    if Y_mu_pred.size == 0:
        return np.zeros(shape)
@@ -491,38 +499,29 @@ def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
    if Y_mu_pred.size == 1:
        mu_mean = np.full(shape, Y_mu_pred.item() / m)
    else:
-        mu_mean = mean(Y_mu_pred) / m
+        mu_mean = fast_mean(Y_mu_pred) / m

    draw = norm + mu_mean
    return draw


-def fast_mean():
-    """If available use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        from functools import partial
-
-        return partial(np.mean, axis=1)
-
-    @jit
-    def mean(a):
-        if a.ndim == 1:
-            count = a.shape[0]
-            suma = 0
+@njit
+def fast_mean(a):
+    """Use Numba to speed up the computation of the mean."""
+
+    if a.ndim == 1:
+        count = a.shape[0]
+        suma = 0
+        for i in range(count):
+            suma += a[i]
+        return suma / count
+    elif a.ndim == 2:
+        res = np.zeros(a.shape[0])
+        count = a.shape[1]
+        for j in range(a.shape[0]):
            for i in range(count):
-                suma += a[i]
-            return suma / count
-        elif a.ndim == 2:
-            res = np.zeros(a.shape[0])
-            count = a.shape[1]
-            for j in range(a.shape[0]):
-                for i in range(count):
-                    res[j] += a[j, i]
-            return res / count
-
-    return mean
+                res[j] += a[j, i]
+        return res / count


def discrete_uniform_sampler(upper_value):
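fast_mean is now a plain @njit function instead of a factory that fell back to np.mean when Numba was missing, so the numba import at the top of the module becomes a hard dependency. For a 2-d input it computes a row-wise mean, i.e. the equivalent of np.mean(a, axis=1). A quick standalone check of that equivalence (assuming numba is installed):

import numpy as np
from numba import njit

@njit
def row_mean(a):
    # Same loop structure as fast_mean's 2-d branch.
    res = np.zeros(a.shape[0])
    for j in range(a.shape[0]):
        for i in range(a.shape[1]):
            res[j] += a[j, i]
    return res / a.shape[1]

a = np.arange(12.0).reshape(3, 4)
assert np.allclose(row_mean(a), a.mean(axis=1))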
@@ -578,6 +577,51 @@ def update(self):
        )


+def systematic(normalized_weights):
+    """
+    Systematic resampling.
+
+    Return indices in the range 2, ..., len(normalized_weights)+2
+
+    Note: adapted from https://github.com/nchopin/particles
+    """
+    lnw = len(normalized_weights)
+    single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
+    return inverse_cdf(single_uniform, normalized_weights) + 2
+
+
+@njit
+def inverse_cdf(single_uniform, normalized_weights):
+    """
+    Inverse CDF algorithm for a finite distribution.
+
+    Parameters
+    ----------
+    single_uniform: ndarray
+        Ordered points in [0, 1].
+    normalized_weights: ndarray
+        Normalized weights.
+
+    Returns
+    -------
+    A: ndarray
+        A vector of indices in the range 0, ..., len(normalized_weights) - 1;
+        the caller (systematic) shifts them by 2.
+
+    Note: adapted from https://github.com/nchopin/particles
+    """
+    j = 0
+    s = normalized_weights[0]
+    M = single_uniform.shape[0]
+    A = np.empty(M, dtype=np.int64)
+    for n in range(M):
+        while single_uniform[n] > s:
+            j += 1
+            s += normalized_weights[j]
+        A[n] = j
+    return A
+
+
def logp(point, out_vars, vars, shared):
    """Compile Aesara function of the model and the input and output variables.