Skip to content

Commit 3a2194d

Browse files
committed
add missing tests
1 parent db814cb commit 3a2194d

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

tests/test_utils.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
import pymc as pm
3+
import pytest
4+
from numpy.testing import assert_almost_equal, assert_array_equal
5+
6+
import pymc_bart as pmb
7+
8+
9+
class TestUtils:
10+
X_norm = np.random.normal(0, 1, size=(50, 2))
11+
X_binom = np.random.binomial(1, 0.5, size=(50, 1))
12+
X = np.hstack([X_norm, X_binom])
13+
Y = np.random.normal(0, 1, size=50)
14+
15+
with pm.Model() as model:
16+
mu = pmb.BART("mu", X, Y, m=10)
17+
sigma = pm.HalfNormal("sigma", 1)
18+
y = pm.Normal("y", mu, sigma, observed=Y)
19+
idata = pm.sample(tune=200, draws=200, random_seed=3415)
20+
21+
def test_sample_posterior(self):
22+
all_trees = self.mu.owner.op.all_trees
23+
rng = np.random.default_rng(3)
24+
pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
25+
rng = np.random.default_rng(3)
26+
pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)
27+
28+
assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
29+
assert pred_all.shape == (2, 50, 1)
30+
assert pred_first.shape == (1, 10, 1)
31+
32+
@pytest.mark.parametrize(
33+
"kwargs",
34+
[
35+
{},
36+
{
37+
"samples": 2,
38+
"var_discrete": [3],
39+
},
40+
{"instances": 2},
41+
{"var_idx": [0], "smooth": False, "color": "k"},
42+
{"grid": (1, 2), "sharey": "none", "alpha": 1},
43+
{"var_discrete": [0]},
44+
],
45+
)
46+
def test_ice(self, kwargs):
47+
pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs)
48+
49+
@pytest.mark.parametrize(
50+
"kwargs",
51+
[
52+
{},
53+
{
54+
"samples": 2,
55+
"xs_interval": "quantiles",
56+
"xs_values": [0.25, 0.5, 0.75],
57+
"var_discrete": [3],
58+
},
59+
{"var_idx": [0], "smooth": False, "color": "k"},
60+
{"grid": (1, 2), "sharey": "none", "alpha": 1},
61+
{"var_discrete": [0]},
62+
],
63+
)
64+
def test_pdp(self, kwargs):
65+
pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs)
66+
67+
@pytest.mark.parametrize(
68+
"kwargs",
69+
[
70+
{"samples": 50},
71+
{"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)},
72+
],
73+
)
74+
def test_vi(self, kwargs):
75+
samples = kwargs.pop("samples")
76+
vi_results = pmb.compute_variable_importance(
77+
self.idata, bartrv=self.mu, X=self.X, samples=samples
78+
)
79+
pmb.plot_variable_importance(vi_results, **kwargs)
80+
pmb.plot_scatter_submodels(vi_results, **kwargs)
81+
82+
def test_pdp_pandas_labels(self):
83+
pd = pytest.importorskip("pandas")
84+
85+
X_names = ["norm1", "norm2", "binom"]
86+
X_pd = pd.DataFrame(self.X, columns=X_names)
87+
Y_pd = pd.Series(self.Y, name="response")
88+
axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd)
89+
90+
figure = axes[0].figure
91+
assert figure.texts[0].get_text() == "Partial response"
92+
assert_array_equal([ax.get_xlabel() for ax in axes], X_names)
93+
94+
95+
def test_encoder_decoder():
96+
"""Test that the encoder-decoder works correctly."""
97+
test_cases = [
98+
np.zeros(3, dtype=int),
99+
np.ones(10, dtype=int),
100+
np.array([4, 0, 1, 0, 2, 0, 3, 0, 0, 0]),
101+
np.array([100, 50, 0, 1]),
102+
np.array([1, 2, 4, 8, 16]),
103+
]
104+
for case in test_cases:
105+
encoded = pmb.utils._encode_vi(case)
106+
decoded = pmb.utils._decode_vi(encoded, len(case))
107+
assert np.array_equal(decoded, case)

0 commit comments

Comments
 (0)