
Commit ab0adfb

modify resampling schema and refactor (#65)
1 parent 0f77300 commit ab0adfb

File tree

4 files changed: +89 −97 lines


pymc_bart/pgbart.py

Lines changed: 49 additions & 92 deletions
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-
 from numba import njit
 
 import numpy as np
@@ -30,8 +28,6 @@
 from pymc_bart.bart import BARTRV
 from pymc_bart.tree import Tree, Node, get_depth
 
-_log = logging.getLogger("pymc")
-
 
 class PGBART(ArrayStepShared):
     """
@@ -41,8 +37,8 @@ class PGBART(ArrayStepShared):
     ----------
     vars: list
         List of value variables for sampler
-    num_particles : int
-        Number of particles for the conditional SMC sampler. Defaults to 20
+    num_particles : tuple
+        Number of particles. Defaults to 20
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
         during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -54,7 +50,7 @@ class PGBART(ArrayStepShared):
     name = "pgbart"
     default_blocked = False
     generates_stats = True
-    stats_dtypes = [{"variable_inclusion": object}]
+    stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
 
     def __init__(
         self,
@@ -89,7 +85,7 @@ def __init__(
         if self.bart.split_prior:
             self.alpha_vec = self.bart.split_prior
         else:
-            self.alpha_vec = np.ones(self.X.shape[1])
+            self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
         init_mean = self.bart.Y.mean()
         # if data is binary
         y_unique = np.unique(self.bart.Y)
@@ -105,7 +101,7 @@ def __init__(
         self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
             config.floatX
         )
-        self.sum_trees_noi = self.sum_trees - (init_mean / self.m)
+        self.sum_trees_noi = self.sum_trees - init_mean
         self.a_tree = Tree.new_tree(
             leaf_node_value=init_mean / self.m,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
@@ -130,31 +126,34 @@ def __init__(
             self.batch = (batch, batch)
 
         self.num_particles = num_particles
-        self.log_num_particles = np.log(num_particles)
-        self.indices = list(range(2, num_particles))
-        self.len_indices = len(self.indices)
-
+        self.indices = list(range(1, num_particles))
         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
         self.all_particles = list(ParticleTree(self.a_tree) for _ in range(self.m))
         self.all_trees = np.array([p.tree for p in self.all_particles])
+        self.lower = 0
+        self.iter = 0
         super().__init__(vars, shared)
 
     def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")
 
-        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
+        upper = min(self.lower + self.batch[~self.tune], self.m)
+        tree_ids = range(self.lower, upper)
+        self.lower = upper if upper < self.m else 0
+
         for tree_id in tree_ids:
+            self.iter += 1
             # Compute the sum of trees without the old tree that we are attempting to replace
             self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
-            # Generate an initial set of SMC particles
-            # at the end of the algorithm we return one of these particles as the new tree
+            # Generate an initial set of particles
+            # at the end we return one of these particles as the new tree
            particles = self.init_particles(tree_id)
 
             while True:
-                # Sample each particle (try to grow each tree), except for the first two
+                # Sample each particle (try to grow each tree), except for the first one
                 stop_growing = True
-                for p in particles[2:]:
+                for p in particles[1:]:
                     tree_grew = p.sample_tree(
                         self.ssv,
                         self.available_predictors,
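Note: the hunk above replaces the random choice of trees to update with a deterministic round-robin sweep. `astep` now walks a `[lower, upper)` window across the `m` trees and wraps back to zero after a full pass. A minimal sketch of that cycling logic, with illustrative values standing in for the class attributes:

# Sketch: round-robin window over m trees (illustrative values, not the class API)
m = 10      # total number of trees
batch = 4   # trees updated per step
lower = 0   # carried across steps, like self.lower

def next_tree_ids():
    global lower
    upper = min(lower + batch, m)
    tree_ids = range(lower, upper)
    lower = upper if upper < m else 0
    return tree_ids

for _ in range(4):
    print(list(next_tree_ids()))
# [0, 1, 2, 3] -> [4, 5, 6, 7] -> [8, 9] -> [0, 1, 2, 3]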
@@ -174,65 +173,55 @@ def astep(self, _):
                         break
 
             # Normalize weights
-            w_t, normalized_weights = self.normalize(particles[2:])
+            normalized_weights = self.normalize(particles[1:])
 
             # Resample
             particles = self.resample(particles, normalized_weights)
 
-            # Set the new weight
-            for p in particles[2:]:
-                p.log_weight = w_t
-
-            for p in particles[2:]:
-                p.log_weight = p.old_likelihood_logp
-
-            _, normalized_weights = self.normalize(particles)
-            # Get the new tree and update
+            normalized_weights = self.normalize(particles)
+            # Get the new particle and associated tree
             self.all_particles[tree_id], new_tree = self.get_particle_tree(
                 particles, normalized_weights
             )
+            # Update the sum of trees
             self.sum_trees = self.sum_trees_noi + new_tree._predict()
+            # To reduce memory usage, we trim the tree
             self.all_trees[tree_id] = new_tree.trim()
 
             if self.tune:
-                self.ssv = SampleSplittingVariable(self.alpha_vec)
+                # Update the splitting variable and the splitting variable sampler
+                if self.iter > self.m:
+                    self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_tree.get_split_variables():
                     self.alpha_vec[index] += 1
             else:
+                # update the variable inclusion
                 for index in new_tree.get_split_variables():
                     variable_inclusion[index] += 1
 
         if not self.tune:
             self.bart.all_trees.append(self.all_trees)
 
-        stats = {"variable_inclusion": variable_inclusion}
+        stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
         return self.sum_trees, [stats]
 
     def normalize(self, particles):
-        """Use logsumexp trick to get w_t and softmax to get normalized_weights.
-
-        w_t is the un-normalized weight per particle, we will assign it to the
-        next round of particles, so they all start with the same weight.
+        """
+        Use softmax to get normalized_weights.
         """
         log_w = np.array([p.log_weight for p in particles])
         log_w_max = log_w.max()
         log_w_ = log_w - log_w_max
-        wei = np.exp(log_w_)
-        w_sum = wei.sum()
-        w_t = log_w_max + np.log(w_sum) - self.log_num_particles
-        normalized_weights = wei / w_sum
-        # stabilize weights to avoid assigning exactly zero probability to a particle
-        normalized_weights += 1e-12
-
-        return w_t, normalized_weights
+        wei = np.exp(log_w_) + 1e-12
+        return wei / wei.sum()
 
     def resample(self, particles, normalized_weights):
         """
-        Use systematic resample for all but first two particles
+        Use systematic resample for all but the first particle
 
         Ensure particles are copied only if needed.
         """
-        new_indices = self.systematic(normalized_weights) + 2
+        new_indices = self.systematic(normalized_weights) + 1
         seen = []
         new_particles = []
         for idx in new_indices:
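Note: the rewritten `normalize` in the hunk above drops the `w_t` bookkeeping and reduces to a numerically stable softmax over the particles' log-weights, with a `1e-12` offset so no particle is assigned exactly zero probability. A self-contained sketch of the same computation:

import numpy as np

def softmax_weights(log_w):
    # Subtract the max before exponentiating to avoid overflow/underflow
    log_w = np.asarray(log_w, dtype=float)
    wei = np.exp(log_w - log_w.max()) + 1e-12
    return wei / wei.sum()

print(softmax_weights([-1000.0, -1001.0, -1005.0]))
# ~[0.73, 0.27, 0.005]; naive np.exp on these values would underflow to all zeros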
@@ -242,18 +231,19 @@ def resample(self, particles, normalized_weights):
                 new_particles.append(particles[idx])
                 seen.append(idx)
 
-        particles[2:] = new_particles
+        particles[1:] = new_particles
 
         return particles
 
     def get_particle_tree(self, particles, normalized_weights):
         """
-        Sample a new particle, new tree and update log_weight
+        Sample a new particle and associated tree
         """
         new_index = self.systematic(normalized_weights)[
             discrete_uniform_sampler(self.num_particles)
         ]
         new_particle = particles[new_index]
+
         return new_particle, new_particle.tree
 
     def systematic(self, normalized_weights):
@@ -265,47 +255,31 @@ def systematic(self, normalized_weights):
         Note: adapted from https://github.com/nchopin/particles
         """
         lnw = len(normalized_weights)
-        single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
+        single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
         return inverse_cdf(single_uniform, normalized_weights)
 
     def init_particles(self, tree_id: int) -> np.ndarray:
         """Initialize particles."""
         p0 = self.all_particles[tree_id]
-        p1 = p0.copy()
-        p1.sample_leafs(
-            self.sum_trees,
-            self.m,
-            self.normal,
-            self.shape,
-        )
-        # The old tree and the one with new leafs do not grow so we update the weights only once
-        self.update_weight(p0, old=True)
-        self.update_weight(p1, old=True)
-        particles = [p0, p1]
+        # The old tree does not grow so we update the weight only once
+        self.update_weight(p0)
+        particles = [p0]
 
         for _ in self.indices:
             particles.append(
-                ParticleTree(self.a_tree, self.uniform_kf.random() if self.tune else p0.kfactor)
+                ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
             )
 
-        return np.array(particles)
+        return particles
 
-    def update_weight(self, particle, old=False):
+    def update_weight(self, particle):
         """
         Update the weight of a particle.
-
-        Since the prior is used as the proposal,the weights are updated additively as the ratio of
-        the new and old log-likelihoods.
         """
         new_likelihood = self.likelihood_logp(
             (self.sum_trees_noi + particle.tree._predict()).flatten()
         )
-        if old:
-            particle.log_weight = new_likelihood
-        else:
-            particle.log_weight += new_likelihood - particle.old_likelihood_logp
-
-        particle.old_likelihood_logp = new_likelihood
+        particle.log_weight = new_likelihood
 
     @staticmethod
     def competence(var, has_grad):
@@ -319,20 +293,17 @@ def competence(var, has_grad):
 class ParticleTree:
     """Particle tree."""
 
-    __slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"
 
     def __init__(self, tree, kfactor=0.75):
         self.tree = tree.copy()
         self.expansion_nodes = [0]
         self.log_weight = 0
-        self.old_likelihood_logp = 0
         self.kfactor = kfactor
 
     def copy(self):
         p = ParticleTree(self.tree)
         p.expansion_nodes = self.expansion_nodes.copy()
-        p.log_weight = self.log_weight
-        p.old_likelihood_logp = self.old_likelihood_logp
         p.kfactor = self.kfactor
         return p
 
@@ -374,20 +345,6 @@ def sample_tree(
 
         return tree_grew
 
-    def sample_leafs(self, sum_trees, m, normal, shape):
-
-        for idx in self.tree.idx_leaf_nodes:
-            if idx > 0:
-                leaf = self.tree[idx]
-                idx_data_points = leaf.idx_data_points
-                node_value = draw_leaf_value(
-                    sum_trees[:, idx_data_points],
-                    m,
-                    normal.random() * self.kfactor,
-                    shape,
-                )
-                leaf.value = node_value
-
 
 class SampleSplittingVariable:
     def __init__(self, alpha_vec):
@@ -471,7 +428,7 @@ def grow_tree(
         node_value = draw_leaf_value(
             sum_trees[:, idx_data_point],
             m,
-            normal.random() * kfactor,
+            normal.rvs() * kfactor,
             shape,
         )
 
@@ -560,7 +517,7 @@ def __init__(self, scale, shape):
         self.shape = shape
         self.update()
 
-    def random(self):
+    def rvs(self):
         if self.idx == self.size:
             self.update()
         pop = self.cache[:, self.idx]
@@ -582,7 +539,7 @@ def __init__(self, lower_bound, upper_bound, shape=None):
         self.shape = shape
         self.update()
 
-    def random(self):
+    def rvs(self):
         if self.idx == self.size:
             self.update()
         if self.shape is None:
@@ -618,7 +575,7 @@ def inverse_cdf(single_uniform, normalized_weights):
     Returns
     -------
     new_indices: ndarray
-        a vector of indices in range 2, ..., len(normalized_weights)+2
+        a vector of indices in range 0, ..., len(normalized_weights)
 
     Note: adapted from https://github.com/nchopin/particles
     """

pymc_bart/utils.py

Lines changed: 1 addition & 0 deletions
@@ -401,6 +401,7 @@ def plot_variable_importance(
 
     axes[1].errorbar(ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)))
 
+    axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
     axes[1].set_xticks(ticks)
     axes[1].set_xticklabels(ticks + 1)
     axes[1].set_xlabel("number of covariables")
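Note: the single added line draws a dashed horizontal reference at `ev_mean[-1]`, the expected value of the model with all covariables, so each restricted model on the variable-importance plot can be read against it. A sketch with made-up numbers standing in for the real `ev_mean`:

import matplotlib.pyplot as plt
import numpy as np

ev_mean = np.array([0.61, 0.78, 0.85, 0.88])  # hypothetical per-subset values
ticks = np.arange(len(ev_mean))

fig, ax = plt.subplots()
ax.errorbar(ticks, ev_mean, yerr=0.05)
ax.axhline(ev_mean[-1], ls="--", color="0.5")  # reference: full-model value
ax.set_xticks(ticks)
ax.set_xticklabels(ticks + 1)
ax.set_xlabel("number of covariables")
plt.show()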

tests/test_bart.py

Lines changed: 35 additions & 1 deletion
@@ -3,10 +3,43 @@
 import pytest
 from numpy.random import RandomState
 from numpy.testing import assert_almost_equal, assert_array_equal
-from pymc.tests.distributions.util import assert_moment_is_expected
 
 import pymc_bart as pmb
 
+from pymc.logprob.joint_logprob import joint_logp
+from pymc.initial_point import make_initial_point_fn
+
+
+def assert_moment_is_expected(model, expected, check_finite_logp=True):
+    fn = make_initial_point_fn(
+        model=model,
+        return_transformed=False,
+        default_strategy="moment",
+    )
+    moment = fn(0)["x"]
+    expected = np.asarray(expected)
+    try:
+        random_draw = model["x"].eval()
+    except NotImplementedError:
+        random_draw = moment
+
+    assert moment.shape == expected.shape
+    assert expected.shape == random_draw.shape
+    assert np.allclose(moment, expected)
+
+    if check_finite_logp:
+        logp_moment = (
+            joint_logp(
+                (model["x"],),
+                rvs_to_values={model["x"]: pm.math.constant(moment)},
+                rvs_to_transforms={},
+                rvs_to_total_sizes={},
+            )[0]
+            .sum()
+            .eval()
+        )
+        assert np.isfinite(logp_moment)
+
 
 def test_bart_vi():
     X = np.random.normal(0, 1, size=(250, 3))
@@ -71,6 +104,7 @@ def test_shape():
     assert idata.posterior.coords["w_dim_0"].data.size == 2
     assert idata.posterior.coords["w_dim_1"].data.size == 250
 
+
 class TestUtils:
     X_norm = np.random.normal(0, 1, size=(50, 2))
     X_binom = np.random.binomial(1, 0.5, size=(50, 1))
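Note: the test module now vendors `assert_moment_is_expected` instead of importing it from `pymc.tests.distributions.util`. A usage sketch under stated assumptions: the helper hard-codes the variable name "x", relies on module-level `np`/`pm` imports, and the `joint_logp` signature matches the pymc version this commit targets:

import numpy as np
import pymc as pm

with pm.Model() as model:
    pm.Normal("x", mu=np.zeros(3), sigma=1.0)  # variable under test must be named "x"

assert_moment_is_expected(model, np.zeros(3))  # moment of a Normal is its mu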
