@@ -14,8 +14,8 @@
 
 import logging
 
-from copy import copy, deepcopy
-from numba import jit
+from copy import deepcopy
+from numba import njit
 
 import aesara
 import numpy as np
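Aside (not part of the commit): numba's njit is shorthand for jit(nopython=True), so the helpers decorated below compile in nopython mode instead of silently falling back to object mode. A minimal sketch of the equivalence:

    from numba import jit, njit

    @njit
    def add_one(x):
        return x + 1

    @jit(nopython=True)  # equivalent spelling of the decorator above
    def add_one_explicit(x):
        return x + 1

    assert add_one(41) == add_one_explicit(41) == 42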
@@ -57,7 +57,7 @@ class PGBART(ArrayStepShared):
     def __init__(
         self,
         vars=None,
-        num_particles=40,
+        num_particles=20,
         batch="auto",
         model=None,
     ):
@@ -105,8 +105,6 @@ def __init__(
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             shape=self.shape,
         )
-        self.mean = fast_mean()
-
         self.normal = NormalSampler(mu_std, self.shape)
         self.uniform = UniformSampler(0.33, 0.75, self.shape)
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -159,7 +157,6 @@ def astep(self, _):
                         self.X,
                         self.missing_data,
                         self.sum_trees,
-                        self.mean,
                         self.m,
                         self.normal,
                         self.shape,
@@ -204,7 +201,7 @@ def astep(self, _):
             for index in used_variates:
                 variable_inclusion[index] += 1
 
-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
+        stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
         return self.sum_trees, [stats]
 
     def normalize(self, particles):
@@ -226,15 +223,20 @@ def normalize(self, particles):
         return w_t, normalized_weights
 
     def resample(self, particles, normalized_weights):
-        """Use systematic resample for all but first two particles"""
+        """
+        Use systematic resample for all but first two particles
+
+        Ensure particles are copied only if needed.
+        """
+        new_indices = systematic(normalized_weights)
         seen = []
         new_particles = []
         for idx in new_indices:
             if idx in seen:
                 new_particles.append(deepcopy(particles[idx]))
             else:
-                seen.append(idx)
                 new_particles.append(particles[idx])
+                seen.append(idx)
 
         particles[2:] = new_particles
 
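Aside (not part of the commit): after systematic resampling the same index can be drawn several times, so resample only deepcopies an index on its second and later occurrences; the first occurrence reuses the existing particle. A standalone sketch of that copy-on-repeat pattern, with a hypothetical Particle stand-in for the tree particles:

    from copy import deepcopy

    class Particle:  # hypothetical stand-in, not the particle class of this file
        def __init__(self, value):
            self.value = value

    particles = [Particle(v) for v in range(5)]
    new_indices = [3, 3, 1]  # e.g. output of systematic resampling

    seen, new_particles = [], []
    for idx in new_indices:
        if idx in seen:
            # already handed out once: copy so the duplicates evolve independently
            new_particles.append(deepcopy(particles[idx]))
        else:
            # first occurrence: reuse the object, no copy needed
            new_particles.append(particles[idx])
            seen.append(idx)

    assert new_particles[0] is particles[3]      # reused as-is
    assert new_particles[1] is not particles[3]  # independent copy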
@@ -243,15 +245,13 @@ def resample(self, particles, normalized_weights):
     def init_particles(self, tree_id: int) -> np.ndarray:
         """Initialize particles."""
         p0 = self.all_particles[tree_id]
-        p1 = copy(p0)
+        p1 = deepcopy(p0)
         p1.sample_leafs(
             self.sum_trees,
-            self.mean,
             self.m,
             self.normal,
             self.shape,
         )
-
         # The old tree and the one with new leafs do not grow so we update the weights only once
         self.update_weight(p0, old=True)
         self.update_weight(p1, old=True)
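Aside (not part of the commit): the switch from copy to deepcopy matters because a shallow copy shares nested mutable state between the two particles, so growing one would silently mutate the other. A minimal illustration with a hypothetical nested object:

    from copy import copy, deepcopy

    class Node:  # hypothetical stand-in for a tree node
        def __init__(self):
            self.children = []

    p0 = Node()
    shallow, deep = copy(p0), deepcopy(p0)
    p0.children.append("grown")
    assert shallow.children == ["grown"]  # shallow copy shares the inner list
    assert deep.children == []            # deep copy does not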
@@ -303,7 +303,6 @@ def __init__(self, tree):
         self.old_likelihood_logp = 0
         self.kf = 0.75
 
-
     def sample_tree(
         self,
         ssv,
@@ -312,7 +311,6 @@ def sample_tree(
         X,
         missing_data,
         sum_trees,
-        mean,
         m,
         normal,
         shape,
@@ -332,7 +330,6 @@ def sample_tree(
                 X,
                 missing_data,
                 sum_trees,
-                mean,
                 m,
                 normal,
                 self.kf,
@@ -345,15 +342,14 @@ def sample_tree(
 
         return tree_grew
 
-    def sample_leafs(self, sum_trees, mean, m, normal, shape):
+    def sample_leafs(self, sum_trees, m, normal, shape):
 
         for idx in self.tree.idx_leaf_nodes:
             if idx > 0:
                 leaf = self.tree[idx]
                 idx_data_points = leaf.idx_data_points
                 node_value = draw_leaf_value(
                     sum_trees[:, idx_data_points],
-                    mean,
                     m,
                     normal,
                     self.kf,
@@ -414,7 +410,6 @@ def grow_tree(
     X,
     missing_data,
     sum_trees,
-    mean,
     m,
     normal,
     kf,
@@ -443,7 +438,6 @@ def grow_tree(
         idx_data_point = new_idx_data_points[idx]
         node_value = draw_leaf_value(
             sum_trees[:, idx_data_point],
-            mean,
             m,
             normal,
             kf,
@@ -496,7 +490,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
     return split_value
 
 
-def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
+def draw_leaf_value(Y_mu_pred, m, normal, kf, shape):
     """Draw Gaussian distributed leaf values."""
     if Y_mu_pred.size == 0:
         return np.zeros(shape)
@@ -505,40 +499,31 @@ def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
     if Y_mu_pred.size == 1:
         mu_mean = np.full(shape, Y_mu_pred.item() / m)
     else:
-        mu_mean = mean(Y_mu_pred) / m
+        mu_mean = fast_mean(Y_mu_pred) / m
 
     draw = norm + mu_mean
     return draw
 
 
-def fast_mean():
-    """If available use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        from functools import partial
+@njit
+def fast_mean(a):
+    """Use Numba to speed up the computation of the mean."""
+
+    if a.ndim == 1:
+        count = a.shape[0]
+        suma = 0
+        for i in range(count):
+            suma += a[i]
+        return suma / count
+    elif a.ndim == 2:
+        res = np.zeros(a.shape[0])
+        count = a.shape[1]
+        for j in range(a.shape[0]):
+            for i in range(count):
+                res[j] += a[j, i]
+        return res / count
 
-        return partial(np.mean, axis=1)
 
-    @jit
-    def mean(a):
-        if a.ndim == 1:
-            count = a.shape[0]
-            suma = 0
-            for i in range(count):
-                suma += a[i]
-            return suma / count
-        elif a.ndim == 2:
-            res = np.zeros(a.shape[0])
-            count = a.shape[1]
-            for j in range(a.shape[0]):
-                for i in range(count):
-                    res[j] += a[j, i]
-            return res / count
-
-        return mean
-
-    @jit()
 def discrete_uniform_sampler(upper_value):
     """Draw from the uniform distribution with bounds [0, upper_value).
 
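Aside (not part of the commit): fast_mean returns the overall mean of a 1D array and the per-row mean (axis=1) of a 2D array; numba can compile both branches because a.ndim is part of the array type and the dead branch is pruned per signature. A quick sanity check against NumPy, assuming numba is installed (the ImportError fallback is gone) and that the module path used below, pymc3.step_methods.pgbart, is correct (hypothetical here):

    import numpy as np
    from pymc3.step_methods.pgbart import fast_mean  # hypothetical import path

    a2 = np.random.rand(3, 100)
    a1 = np.random.rand(100)
    assert np.allclose(fast_mean(a2), a2.mean(axis=1))  # 2D: row means
    assert np.isclose(fast_mean(a1), a1.mean())         # 1D: overall mean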
@@ -555,15 +540,14 @@ def __init__(self, scale, shape):
         self.scale = scale
         self.shape = shape
         self.update()
-
+
     def random(self):
         if self.idx == self.size:
             self.update()
         pop = self.cache[:, self.idx]
         self.idx += 1
         return pop
 
-
     def update(self):
         self.idx = 0
         self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
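Aside (not part of the commit): NormalSampler amortizes the cost of random number generation by drawing a (shape, size) block up front and handing out one column per call, refilling only when the cache runs out. A self-contained sketch of the same caching pattern; the size attribute is set elsewhere in the real class, so it is a constructor argument here:

    import numpy as np

    class CachedNormalSampler:  # simplified sketch of the pattern
        def __init__(self, scale, shape, size=1000):
            self.scale = scale
            self.shape = shape
            self.size = size
            self.update()

        def random(self):
            if self.idx == self.size:  # cache exhausted: draw a fresh block
                self.update()
            pop = self.cache[:, self.idx]
            self.idx += 1
            return pop

        def update(self):
            self.idx = 0
            self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))

    sampler = CachedNormalSampler(scale=1.0, shape=2)
    assert sampler.random().shape == (2,)  # one (shape,) vector per call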
@@ -586,44 +570,54 @@ def random(self):
         self.idx += 1
         return pop
 
-
     def update(self):
         self.idx = 0
         self.cache = np.random.uniform(
             self.lower_bound, self.upper_bound, size=(self.shape, self.size)
         )
 
-@jit()
-def systematic(W):
-    """Systematic resampling.
+
+def systematic(normalized_weights):
     """
-    M = len(W)
-    su = (np.random.rand(1) + np.arange(M)) / M
-    return inverse_cdf(su, W) + 2
-
-
-@jit(nopython=True)
-def inverse_cdf(su, W):
-    """Inverse CDF algorithm for a finite distribution.
-    Parameters
-    ----------
-    su: (M,) ndarray
-        M sorted uniform variates (i.e. M ordered points in [0,1]).
-    W: (N,) ndarray
-        a vector of N normalized weights (>=0 and sum to one)
-    Returns
-    -------
-    A: (M,) ndarray
-        a vector of M indices in range 0, ..., N-1
+    Systematic resampling.
+
+    Return indices in the range 2, ..., len(normalized_weights)+2
+
+    Note: adapted from https://github.com/nchopin/particles
+    """
+    lnw = len(normalized_weights)
+    single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
+    return inverse_cdf(single_uniform, normalized_weights) + 2
+
+
+@njit
+def inverse_cdf(single_uniform, normalized_weights):
+    """
+    Inverse CDF algorithm for a finite distribution.
+
+    Parameters
+    ----------
+    single_uniform: ndarray
+        ordered points in [0,1]
+
+    normalized_weights: ndarray
+        normalized weights
+
+    Returns
+    -------
+    A: ndarray
+        a vector of indices in range 2, ..., len(normalized_weights)+2
+
+    Note: adapted from https://github.com/nchopin/particles
     """
     j = 0
-    s = W[0]
-    M = su.shape[0]
+    s = normalized_weights[0]
+    M = single_uniform.shape[0]
     A = np.empty(M, dtype=np.int64)
     for n in range(M):
-        while su[n] > s:
+        while single_uniform[n] > s:
             j += 1
-            s += W[j]
+            s += normalized_weights[j]
         A[n] = j
     return A
 
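Aside (not part of the commit): systematic resampling spreads M evenly spaced points over [0, 1) with a single shared uniform jitter and pushes them through the inverse CDF of the weights, so an index with weight w receives roughly w * M of the M draws; the +2 offset in systematic reflects that resample keeps the first two particles untouched. A plain NumPy version of the same inverse-CDF walk, without the offset:

    import numpy as np

    def inverse_cdf_py(single_uniform, normalized_weights):
        # walk the weight CDF once; each ordered point lands in one bin
        j = 0
        s = normalized_weights[0]
        M = single_uniform.shape[0]
        A = np.empty(M, dtype=np.int64)
        for n in range(M):
            while single_uniform[n] > s:
                j += 1
                s += normalized_weights[j]
            A[n] = j
        return A

    weights = np.array([0.5, 0.3, 0.2])
    M = len(weights)
    points = (np.random.rand(1) + np.arange(M)) / M  # ordered, one shared jitter
    print(inverse_cdf_py(points, weights))  # e.g. [0 0 1] or [0 1 2]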