@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-
 from numba import njit

 import numpy as np
@@ -30,8 +28,6 @@
 from pymc_bart.bart import BARTRV
 from pymc_bart.tree import Tree, Node, get_depth

-_log = logging.getLogger("pymc")
-

 class PGBART(ArrayStepShared):
     """
@@ -41,8 +37,8 @@ class PGBART(ArrayStepShared):
     ----------
     vars: list
         List of value variables for sampler
     num_particles : int
-        Number of particles for the conditional SMC sampler. Defaults to 20
+        Number of particles. Defaults to 20
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
         during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -54,7 +50,7 @@ class PGBART(ArrayStepShared):
     name = "pgbart"
     default_blocked = False
     generates_stats = True
-    stats_dtypes = [{"variable_inclusion": object}]
+    stats_dtypes = [{"variable_inclusion": object, "tune": bool}]

     def __init__(
         self,
@@ -89,7 +85,7 @@ def __init__(
         if self.bart.split_prior:
             self.alpha_vec = self.bart.split_prior
         else:
-            self.alpha_vec = np.ones(self.X.shape[1])
+            self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
         init_mean = self.bart.Y.mean()
         # if data is binary
         y_unique = np.unique(self.bart.Y)
@@ -105,7 +101,7 @@ def __init__(
         self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
             config.floatX
         )
-        self.sum_trees_noi = self.sum_trees - (init_mean / self.m)
+        self.sum_trees_noi = self.sum_trees - init_mean
         self.a_tree = Tree.new_tree(
             leaf_node_value=init_mean / self.m,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
@@ -130,31 +126,34 @@ def __init__(
             self.batch = (batch, batch)

         self.num_particles = num_particles
-        self.log_num_particles = np.log(num_particles)
-        self.indices = list(range(2, num_particles))
-        self.len_indices = len(self.indices)
-
+        self.indices = list(range(1, num_particles))
         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
         self.all_particles = list(ParticleTree(self.a_tree) for _ in range(self.m))
         self.all_trees = np.array([p.tree for p in self.all_particles])
+        self.lower = 0
+        self.iter = 0
         super().__init__(vars, shared)

     def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
+        upper = min(self.lower + self.batch[~self.tune], self.m)
+        tree_ids = range(self.lower, upper)
+        self.lower = upper if upper < self.m else 0
+
         for tree_id in tree_ids:
+            self.iter += 1
             # Compute the sum of trees without the old tree that we are attempting to replace
             self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
-            # Generate an initial set of SMC particles
-            # at the end of the algorithm we return one of these particles as the new tree
+            # Generate an initial set of particles
+            # at the end we return one of these particles as the new tree
             particles = self.init_particles(tree_id)

             while True:
-                # Sample each particle (try to grow each tree), except for the first two
+                # Sample each particle (try to grow each tree), except for the first one
                 stop_growing = True
-                for p in particles[2:]:
+                for p in particles[1:]:
                     tree_grew = p.sample_tree(
                         self.ssv,
                         self.available_predictors,
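Note on the hunk above: tree selection switches from a random subset (`np.random.choice` without replacement) to a deterministic round-robin sweep, so every tree is refit at the same frequency and `self.iter` can count how many trees have been visited. A minimal standalone sketch of the cycling logic; the function and names below are illustrative, not part of the codebase:

```python
def next_batch(lower: int, batch: int, m: int):
    """Return the tree ids for one step and the updated lower bound."""
    upper = min(lower + batch, m)
    tree_ids = range(lower, upper)
    # wrap around once the sweep reaches the last tree
    return tree_ids, (upper if upper < m else 0)

lower = 0
for _ in range(4):
    ids, lower = next_batch(lower, batch=4, m=10)
    print(list(ids))  # [0..3], [4..7], [8, 9], then wraps back to [0..3]
```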
@@ -174,65 +173,55 @@ def astep(self, _):
                     break

                 # Normalize weights
-                w_t, normalized_weights = self.normalize(particles[2:])
+                normalized_weights = self.normalize(particles[1:])

                 # Resample
                 particles = self.resample(particles, normalized_weights)

-                # Set the new weight
-                for p in particles[2:]:
-                    p.log_weight = w_t
-
-            for p in particles[2:]:
-                p.log_weight = p.old_likelihood_logp
-
-            _, normalized_weights = self.normalize(particles)
-            # Get the new tree and update
+            normalized_weights = self.normalize(particles)
+            # Get the new particle and associated tree
             self.all_particles[tree_id], new_tree = self.get_particle_tree(
                 particles, normalized_weights
             )
+            # Update the sum of trees
             self.sum_trees = self.sum_trees_noi + new_tree._predict()
+            # To reduce memory usage, we trim the tree
             self.all_trees[tree_id] = new_tree.trim()

             if self.tune:
-                self.ssv = SampleSplittingVariable(self.alpha_vec)
+                # Update the splitting variable and the splitting variable sampler
+                if self.iter > self.m:
+                    self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_tree.get_split_variables():
                     self.alpha_vec[index] += 1
             else:
+                # update the variable inclusion
                 for index in new_tree.get_split_variables():
                     variable_inclusion[index] += 1

         if not self.tune:
             self.bart.all_trees.append(self.all_trees)

-        stats = {"variable_inclusion": variable_inclusion}
+        stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
         return self.sum_trees, [stats]

     def normalize(self, particles):
-        """Use logsumexp trick to get w_t and softmax to get normalized_weights.
-
-        w_t is the un-normalized weight per particle, we will assign it to the
-        next round of particles, so they all start with the same weight.
+        """
+        Use softmax to get normalized_weights.
         """
         log_w = np.array([p.log_weight for p in particles])
         log_w_max = log_w.max()
         log_w_ = log_w - log_w_max
-        wei = np.exp(log_w_)
-        w_sum = wei.sum()
-        w_t = log_w_max + np.log(w_sum) - self.log_num_particles
-        normalized_weights = wei / w_sum
-        # stabilize weights to avoid assigning exactly zero probability to a particle
-        normalized_weights += 1e-12
-
-        return w_t, normalized_weights
+        wei = np.exp(log_w_) + 1e-12
+        return wei / wei.sum()

     def resample(self, particles, normalized_weights):
         """
-        Use systematic resample for all but first two particles
+        Use systematic resample for all but the first particle

         Ensure particles are copied only if needed.
         """
-        new_indices = self.systematic(normalized_weights) + 2
+        new_indices = self.systematic(normalized_weights) + 1
         seen = []
         new_particles = []
         for idx in new_indices:
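Note on `normalize` above: the logsumexp bookkeeping for `w_t` disappears because particles no longer carry a running weight between rounds; only the softmax of the current log-weights is needed. The `1e-12` floor now goes in before normalization, so the returned weights sum to exactly one while no particle gets probability zero. A standalone sketch with illustrative names:

```python
import numpy as np

def softmax_weights(log_w: np.ndarray) -> np.ndarray:
    log_w = log_w - log_w.max()  # shift so exp() cannot overflow
    wei = np.exp(log_w) + 1e-12  # floor keeps every particle selectable
    return wei / wei.sum()       # exact normalization

print(softmax_weights(np.array([-1000.0, -1001.0, -1005.0])).sum())  # 1.0
```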
@@ -242,18 +231,19 @@ def resample(self, particles, normalized_weights):
                 new_particles.append(particles[idx])
                 seen.append(idx)

-        particles[2:] = new_particles
+        particles[1:] = new_particles

         return particles

     def get_particle_tree(self, particles, normalized_weights):
         """
-        Sample a new particle, new tree and update log_weight
+        Sample a new particle and associated tree
         """
         new_index = self.systematic(normalized_weights)[
             discrete_uniform_sampler(self.num_particles)
         ]
         new_particle = particles[new_index]
+
         return new_particle, new_particle.tree

     def systematic(self, normalized_weights):
@@ -265,47 +255,31 @@ def systematic(self, normalized_weights):
         Note: adapted from https://github.com/nchopin/particles
         """
         lnw = len(normalized_weights)
-        single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
+        single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
         return inverse_cdf(single_uniform, normalized_weights)

-    def init_particles(self, tree_id: int) -> np.ndarray:
+    def init_particles(self, tree_id: int) -> list:
         """Initialize particles."""
         p0 = self.all_particles[tree_id]
-        p1 = p0.copy()
-        p1.sample_leafs(
-            self.sum_trees,
-            self.m,
-            self.normal,
-            self.shape,
-        )
-        # The old tree and the one with new leafs do not grow so we update the weights only once
-        self.update_weight(p0, old=True)
-        self.update_weight(p1, old=True)
-        particles = [p0, p1]
+        # The old tree does not grow so we update the weight only once
+        self.update_weight(p0)
+        particles = [p0]

         for _ in self.indices:
             particles.append(
-                ParticleTree(self.a_tree, self.uniform_kf.random() if self.tune else p0.kfactor)
+                ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
             )

-        return np.array(particles)
+        return particles

-    def update_weight(self, particle, old=False):
+    def update_weight(self, particle):
         """
         Update the weight of a particle.
-
-        Since the prior is used as the proposal, the weights are updated additively as the ratio of
-        the new and old log-likelihoods.
         """
         new_likelihood = self.likelihood_logp(
             (self.sum_trees_noi + particle.tree._predict()).flatten()
         )
-        if old:
-            particle.log_weight = new_likelihood
-        else:
-            particle.log_weight += new_likelihood - particle.old_likelihood_logp
-
-        particle.old_likelihood_logp = new_likelihood
+        particle.log_weight = new_likelihood

     @staticmethod
     def competence(var, has_grad):
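Note on `update_weight` above: with the prior used as the proposal and the old-likelihood bookkeeping removed, a particle's log-weight is now simply the data log-likelihood of the forest with that particle's tree swapped in. An illustrative sketch, using a Gaussian log-likelihood as a stand-in for the model's `likelihood_logp`:

```python
import numpy as np

def log_weight(y, sum_trees_noi, tree_pred, sigma=1.0):
    """Log-likelihood of the forest with one candidate tree swapped in."""
    resid = y - (sum_trees_noi + tree_pred)
    return -0.5 * np.sum((resid / sigma) ** 2)

y = np.array([1.0, 2.0, 3.0])
noi = np.array([0.5, 1.5, 2.5])  # forest prediction without the tree being replaced
print(log_weight(y, noi, tree_pred=np.full(3, 0.5)))  # -0.0, i.e. a perfect fit
```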
@@ -319,20 +293,17 @@ def competence(var, has_grad):
 class ParticleTree:
     """Particle tree."""

-    __slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"

     def __init__(self, tree, kfactor=0.75):
         self.tree = tree.copy()
         self.expansion_nodes = [0]
         self.log_weight = 0
-        self.old_likelihood_logp = 0
         self.kfactor = kfactor

     def copy(self):
         p = ParticleTree(self.tree)
         p.expansion_nodes = self.expansion_nodes.copy()
-        p.log_weight = self.log_weight
-        p.old_likelihood_logp = self.old_likelihood_logp
         p.kfactor = self.kfactor
         return p

@@ -374,20 +345,6 @@ def sample_tree(

         return tree_grew

-    def sample_leafs(self, sum_trees, m, normal, shape):
-
-        for idx in self.tree.idx_leaf_nodes:
-            if idx > 0:
-                leaf = self.tree[idx]
-                idx_data_points = leaf.idx_data_points
-                node_value = draw_leaf_value(
-                    sum_trees[:, idx_data_points],
-                    m,
-                    normal.random() * self.kfactor,
-                    shape,
-                )
-                leaf.value = node_value
-

 class SampleSplittingVariable:
     def __init__(self, alpha_vec):
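`SampleSplittingVariable` (its body is unchanged and mostly elided here) draws the predictor used for a split with probability proportional to `alpha_vec`, which during tuning accumulates one count per split (see the `alpha_vec[index] += 1` line above). A hedged sketch of that proportional draw, not the class's actual implementation:

```python
import numpy as np

def sample_split_variable(alpha_vec: np.ndarray, u: float) -> int:
    """Pick a predictor index with probability proportional to its count."""
    cumulative = np.cumsum(alpha_vec / alpha_vec.sum())
    return int(np.searchsorted(cumulative, u))

counts = np.array([5, 1, 1], dtype=np.int32)  # predictor 0 chosen in 5 splits so far
print(sample_split_variable(counts, u=0.5))   # 0: frequently used predictors dominate
```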
@@ -471,7 +428,7 @@ def grow_tree(
         node_value = draw_leaf_value(
             sum_trees[:, idx_data_point],
             m,
-            normal.random() * kfactor,
+            normal.rvs() * kfactor,
             shape,
         )

@@ -560,7 +517,7 @@ def __init__(self, scale, shape):
         self.shape = shape
         self.update()

-    def random(self):
+    def rvs(self):
         if self.idx == self.size:
             self.update()
         pop = self.cache[:, self.idx]
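The `random()` to `rvs()` rename across the cached samplers matches SciPy's frozen-distribution naming. A minimal sketch of the cached-draw pattern behind these classes; the constructor details are inferred from the fragments shown above, so treat them as assumptions:

```python
import numpy as np

class CachedNormal:
    """Pre-draw a block of normals and serve them one column at a time."""

    def __init__(self, scale: float, shape: int, size: int = 1000):
        self.scale, self.shape, self.size = scale, shape, size
        self.update()

    def update(self):
        # refill the cache with a fresh block of draws and reset the cursor
        self.cache = np.random.normal(0.0, self.scale, size=(self.shape, self.size))
        self.idx = 0

    def rvs(self) -> np.ndarray:
        if self.idx == self.size:
            self.update()
        pop = self.cache[:, self.idx]
        self.idx += 1
        return pop

print(CachedNormal(scale=1.0, shape=2).rvs().shape)  # (2,)
```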
@@ -582,7 +539,7 @@ def __init__(self, lower_bound, upper_bound, shape=None):
         self.shape = shape
         self.update()

-    def random(self):
+    def rvs(self):
         if self.idx == self.size:
             self.update()
         if self.shape is None:
@@ -618,7 +575,7 @@ def inverse_cdf(single_uniform, normalized_weights):
     Returns
     -------
     new_indices: ndarray
-        a vector of indices in range 2, ..., len(normalized_weights)+2
+        a vector of indices in range 0, ..., len(normalized_weights)

     Note: adapted from https://github.com/nchopin/particles
     """