Skip to content

Commit cf14591

Browse files
authored
Merge pull request #6 from pymc-devs/resample
Use systematic resample
2 parents 55e8f05 + ecef8a7 commit cf14591

File tree

3 files changed

+103
-51
lines changed

3 files changed

+103
-51
lines changed

pymc_bart/pgbart.py

Lines changed: 95 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import logging
1616

17-
from copy import copy
17+
from copy import deepcopy
18+
from numba import njit
1819

1920
import aesara
2021
import numpy as np
@@ -56,7 +57,7 @@ class PGBART(ArrayStepShared):
5657
def __init__(
5758
self,
5859
vars=None,
59-
num_particles=40,
60+
num_particles=20,
6061
batch="auto",
6162
model=None,
6263
):
@@ -104,8 +105,6 @@ def __init__(
104105
idx_data_points=np.arange(self.num_observations, dtype="int32"),
105106
shape=self.shape,
106107
)
107-
self.mean = fast_mean()
108-
109108
self.normal = NormalSampler(mu_std, self.shape)
110109
self.uniform = UniformSampler(0.33, 0.75, self.shape)
111110
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -158,7 +157,6 @@ def astep(self, _):
158157
self.X,
159158
self.missing_data,
160159
self.sum_trees,
161-
self.mean,
162160
self.m,
163161
self.normal,
164162
self.shape,
@@ -173,11 +171,8 @@ def astep(self, _):
173171
# Normalize weights
174172
w_t, normalized_weights = self.normalize(particles[2:])
175173

176-
# Resample all but first two particles
177-
new_indices = np.random.choice(
178-
self.indices, size=self.len_indices, p=normalized_weights
179-
)
180-
particles[2:] = particles[new_indices]
174+
# Resample
175+
particles = self.resample(particles, normalized_weights)
181176

182177
# Set the new weight
183178
for p in particles[2:]:
@@ -196,15 +191,17 @@ def astep(self, _):
196191
self.sum_trees = self.sum_trees_noi + new_tree._predict()
197192
self.all_trees[tree_id] = new_tree.trim()
198193

194+
used_variates = new_tree.get_split_variables()
195+
199196
if self.tune:
200197
self.ssv = SampleSplittingVariable(self.alpha_vec)
201-
for index in new_particle.used_variates:
198+
for index in used_variates:
202199
self.alpha_vec[index] += 1
203200
else:
204-
for index in new_particle.used_variates:
201+
for index in used_variates:
205202
variable_inclusion[index] += 1
206203

207-
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
204+
stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
208205
return self.sum_trees, [stats]
209206

210207
def normalize(self, particles):
@@ -225,18 +222,36 @@ def normalize(self, particles):
225222

226223
return w_t, normalized_weights
227224

225+
def resample(self, particles, normalized_weights):
226+
"""
227+
Use systematic resample for all but first two particles
228+
229+
Ensure particles are copied only if needed.
230+
"""
231+
new_indices = systematic(normalized_weights)
232+
seen = []
233+
new_particles = []
234+
for idx in new_indices:
235+
if idx in seen:
236+
new_particles.append(deepcopy(particles[idx]))
237+
else:
238+
new_particles.append(particles[idx])
239+
seen.append(idx)
240+
241+
particles[2:] = new_particles
242+
243+
return particles
244+
228245
def init_particles(self, tree_id: int) -> np.ndarray:
229246
"""Initialize particles."""
230247
p0 = self.all_particles[tree_id]
231-
p1 = copy(p0)
248+
p1 = deepcopy(p0)
232249
p1.sample_leafs(
233250
self.sum_trees,
234-
self.mean,
235251
self.m,
236252
self.normal,
237253
self.shape,
238254
)
239-
240255
# The old tree and the one with new leafs do not grow so we update the weights only once
241256
self.update_weight(p0, old=True)
242257
self.update_weight(p1, old=True)
@@ -286,7 +301,6 @@ def __init__(self, tree):
286301
self.expansion_nodes = [0]
287302
self.log_weight = 0
288303
self.old_likelihood_logp = 0
289-
self.used_variates = []
290304
self.kf = 0.75
291305

292306
def sample_tree(
@@ -297,7 +311,6 @@ def sample_tree(
297311
X,
298312
missing_data,
299313
sum_trees,
300-
mean,
301314
m,
302315
normal,
303316
shape,
@@ -317,7 +330,6 @@ def sample_tree(
317330
X,
318331
missing_data,
319332
sum_trees,
320-
mean,
321333
m,
322334
normal,
323335
self.kf,
@@ -326,20 +338,18 @@ def sample_tree(
326338
if index_selected_predictor is not None:
327339
new_indexes = self.tree.idx_leaf_nodes[-2:]
328340
self.expansion_nodes.extend(new_indexes)
329-
self.used_variates.append(index_selected_predictor)
330341
tree_grew = True
331342

332343
return tree_grew
333344

334-
def sample_leafs(self, sum_trees, mean, m, normal, shape):
345+
def sample_leafs(self, sum_trees, m, normal, shape):
335346

336347
for idx in self.tree.idx_leaf_nodes:
337348
if idx > 0:
338349
leaf = self.tree[idx]
339350
idx_data_points = leaf.idx_data_points
340351
node_value = draw_leaf_value(
341352
sum_trees[:, idx_data_points],
342-
mean,
343353
m,
344354
normal,
345355
self.kf,
@@ -400,7 +410,6 @@ def grow_tree(
400410
X,
401411
missing_data,
402412
sum_trees,
403-
mean,
404413
m,
405414
normal,
406415
kf,
@@ -429,7 +438,6 @@ def grow_tree(
429438
idx_data_point = new_idx_data_points[idx]
430439
node_value = draw_leaf_value(
431440
sum_trees[:, idx_data_point],
432-
mean,
433441
m,
434442
normal,
435443
kf,
@@ -482,7 +490,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
482490
return split_value
483491

484492

485-
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
493+
def draw_leaf_value(Y_mu_pred, m, normal, kf, shape):
486494
"""Draw Gaussian distributed leaf values."""
487495
if Y_mu_pred.size == 0:
488496
return np.zeros(shape)
@@ -491,38 +499,29 @@ def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
491499
if Y_mu_pred.size == 1:
492500
mu_mean = np.full(shape, Y_mu_pred.item() / m)
493501
else:
494-
mu_mean = mean(Y_mu_pred) / m
502+
mu_mean = fast_mean(Y_mu_pred) / m
495503

496504
draw = norm + mu_mean
497505
return draw
498506

499507

500-
def fast_mean():
501-
"""If available use Numba to speed up the computation of the mean."""
502-
try:
503-
from numba import jit
504-
except ImportError:
505-
from functools import partial
506-
507-
return partial(np.mean, axis=1)
508-
509-
@jit
510-
def mean(a):
511-
if a.ndim == 1:
512-
count = a.shape[0]
513-
suma = 0
508+
@njit
509+
def fast_mean(a):
510+
"""Use Numba to speed up the computation of the mean."""
511+
512+
if a.ndim == 1:
513+
count = a.shape[0]
514+
suma = 0
515+
for i in range(count):
516+
suma += a[i]
517+
return suma / count
518+
elif a.ndim == 2:
519+
res = np.zeros(a.shape[0])
520+
count = a.shape[1]
521+
for j in range(a.shape[0]):
514522
for i in range(count):
515-
suma += a[i]
516-
return suma / count
517-
elif a.ndim == 2:
518-
res = np.zeros(a.shape[0])
519-
count = a.shape[1]
520-
for j in range(a.shape[0]):
521-
for i in range(count):
522-
res[j] += a[j, i]
523-
return res / count
524-
525-
return mean
523+
res[j] += a[j, i]
524+
return res / count
526525

527526

528527
def discrete_uniform_sampler(upper_value):
@@ -578,6 +577,51 @@ def update(self):
578577
)
579578

580579

580+
def systematic(normalized_weights):
581+
"""
582+
Systematic resampling.
583+
584+
Return indices in the range 2, ..., len(normalized_weights)+2
585+
586+
Note: adapted from https://github.com/nchopin/particles
587+
"""
588+
lnw = len(normalized_weights)
589+
single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
590+
return inverse_cdf(single_uniform, normalized_weights) + 2
591+
592+
593+
@njit
594+
def inverse_cdf(single_uniform, normalized_weights):
595+
"""
596+
Inverse CDF algorithm for a finite distribution.
597+
598+
Parameters
599+
----------
600+
single_uniform: ndarray
601+
ordered points in [0,1]
602+
603+
normalized_weights: ndarray
604+
normalized weights
605+
606+
Returns
607+
-------
608+
A: ndarray
609+
a vector of indices in range 2, ..., len(normalized_weights)+2
610+
611+
Note: adapted from https://github.com/nchopin/particles
612+
"""
613+
j = 0
614+
s = normalized_weights[0]
615+
M = single_uniform.shape[0]
616+
A = np.empty(M, dtype=np.int64)
617+
for n in range(M):
618+
while single_uniform[n] > s:
619+
j += 1
620+
s += normalized_weights[j]
621+
A[n] = j
622+
return A
623+
624+
581625
def logp(point, out_vars, vars, shared):
582626
"""Compile Aesara function of the model and the input and output variables.
583627

pymc_bart/tree.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ def trim(self):
8383
del current_node.idx_data_points
8484
return a_tree
8585

86+
def get_split_variables(self):
87+
return [
88+
node.idx_split_variable
89+
for node in self.tree_structure.values()
90+
if isinstance(node, SplitNode)
91+
]
92+
8693
def _predict(self):
8794
output = self.output
8895
for node_index in self.idx_leaf_nodes:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pymc>=4.1.7
22
arviz>=0.12.1
3+
numba>=0.55.1

0 commit comments

Comments
 (0)