Skip to content

Commit 0f77300

Browse files
authored
add tests for individual functions/methods in PGBART (#64)
* add test pgbart, small refactor systematic resampling * remove comment
1 parent c88edea commit 0f77300

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

pymc_bart/pgbart.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def resample(self, particles, normalized_weights):
232232
233233
Ensure particles are copied only if needed.
234234
"""
235-
new_indices = self.systematic(normalized_weights)
235+
new_indices = self.systematic(normalized_weights) + 2
236236
seen = []
237237
new_particles = []
238238
for idx in new_indices:
@@ -253,20 +253,20 @@ def get_particle_tree(self, particles, normalized_weights):
253253
new_index = self.systematic(normalized_weights)[
254254
discrete_uniform_sampler(self.num_particles)
255255
]
256-
new_particle = particles[new_index - 2]
256+
new_particle = particles[new_index]
257257
return new_particle, new_particle.tree
258258

259259
def systematic(self, normalized_weights):
260260
"""
261261
Systematic resampling.
262262
263-
Return indices in the range 2, ..., len(normalized_weights)+2
263+
Return indices in the range 0, ..., len(normalized_weights)
264264
265265
Note: adapted from https://github.com/nchopin/particles
266266
"""
267267
lnw = len(normalized_weights)
268268
single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
269-
return inverse_cdf(single_uniform, normalized_weights) + 2
269+
return inverse_cdf(single_uniform, normalized_weights)
270270

271271
def init_particles(self, tree_id: int) -> np.ndarray:
272272
"""Initialize particles."""

tests/test_pgbart.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from unittest import TestCase
2+
import numpy as np
3+
import pymc as pm
4+
import pymc_bart as pmb
5+
from pymc_bart.pgbart import fast_mean, discrete_uniform_sampler, NormalSampler, UniformSampler
6+
7+
8+
class TestSystematic(TestCase):
9+
def test_systematic(self):
10+
X = np.random.normal(0, 1, size=(250, 3))
11+
Y = np.random.normal(0, 1, size=250)
12+
X[:, 0] = np.random.normal(Y, 0.1)
13+
14+
with pm.Model() as model:
15+
mu = pmb.BART("mu", X, Y, m=10)
16+
sigma = pm.HalfNormal("sigma", 1)
17+
y = pm.Normal("y", mu, sigma, observed=Y)
18+
step = pmb.PGBART([mu])
19+
20+
normalized_weights = np.array([0.5, 0.3, 0.2])
21+
indices = step.systematic(normalized_weights)
22+
23+
self.assertEqual(len(indices), len(normalized_weights))
24+
self.assertEqual(indices.dtype, np.int)
25+
self.assertTrue(all(i >= 0 and i < len(normalized_weights) for i in indices))
26+
27+
normalized_weights = np.array([0, 0.25, 0.75])
28+
indices = step.systematic(normalized_weights)
29+
self.assertTrue(all(i >= 1 and i < len(normalized_weights) for i in indices))
30+
31+
32+
def test_fast_mean():
33+
values = np.random.uniform(size=10)
34+
np.testing.assert_almost_equal(fast_mean(values), np.mean(values))
35+
36+
values = np.random.uniform(size=(2, 10))
37+
np.testing.assert_array_almost_equal(fast_mean(values), np.mean(values, 1))
38+
39+
40+
def test_discrete_uniform():
41+
sample = discrete_uniform_sampler(7)
42+
assert isinstance(sample, int)
43+
samples = np.array([discrete_uniform_sampler(7) for i in range(1000)])
44+
assert all(samples >= 0)
45+
assert all(samples < 7)
46+
47+
48+
def test_normal_sampler():
49+
normal = NormalSampler(2, shape=1)
50+
samples = np.array([normal.random() for i in range(100000)])
51+
np.testing.assert_almost_equal(samples.mean(), 0, decimal=2)
52+
np.testing.assert_almost_equal(samples.std(), 2, decimal=2)
53+
54+
normal = NormalSampler(2, shape=2)
55+
samples = np.array([normal.random() for i in range(100000)])
56+
np.testing.assert_almost_equal(samples.mean(0), [0, 0], decimal=2)
57+
np.testing.assert_almost_equal(samples.std(0), [2, 2], decimal=2)
58+
59+
60+
def test_uniform_sampler():
61+
uniform = UniformSampler(0.5, 2, shape=1)
62+
samples = np.array([uniform.random() for i in range(100000)])
63+
np.testing.assert_almost_equal(samples.mean(), 1.25, decimal=2)
64+
np.testing.assert_almost_equal(samples.std(), 0.43, decimal=2)
65+
66+
uniform = UniformSampler(0.5, 2, shape=2)
67+
samples = np.array([uniform.random() for i in range(100000)])
68+
np.testing.assert_almost_equal(samples.mean(0), [1.25, 1.25], decimal=2)
69+
np.testing.assert_almost_equal(samples.std(0), [0.43, 0.43], decimal=2)

0 commit comments

Comments
 (0)