Skip to content

Commit aedc84e

Browse files
authored
make random sample match variable shape (#59)
1 parent 9467104 commit aedc84e

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc_bart/bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
5050
else:
5151
return np.full(cls.Y.shape[0], cls.Y.mean())
5252
else:
53-
return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze()
53+
return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T
5454

5555

5656
bart = BARTRV()

tests/test_bart.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ def test_shared_variable():
5858
assert ppc2.posterior_predictive["y"].shape == (2, 100, 3)
5959

6060

61+
def test_shape():
62+
X = np.random.normal(0, 1, size=(250, 3))
63+
Y = np.random.normal(0, 1, size=250)
64+
65+
with pm.Model() as model:
66+
w = pmb.BART("w", X, Y, m=2, shape=(2, 250))
67+
y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
68+
idata = pm.sample(random_seed=3415)
69+
70+
assert model.initial_point()["w"].shape == (2, 250)
71+
assert idata.posterior.coords["w_dim_0"].data.size == 2
72+
assert idata.posterior.coords["w_dim_1"].data.size == 250
73+
6174
class TestUtils:
6275
X_norm = np.random.normal(0, 1, size=(50, 2))
6376
X_binom = np.random.binomial(1, 0.5, size=(50, 1))

0 commit comments

Comments
 (0)