Skip to content

Commit 99209df

Browse files
committed
fix shapes
1 parent ed7c4d0 commit 99209df

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

pymc_bart/bart.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,22 @@ class BARTRV(RandomVariable):
3737
"""Base class for BART."""
3838

3939
name: str = "BART"
40-
signature = "(n,d),(n),(),(),(),(n)->(n)"
40+
signature = "(m,n),(m),(),(),() -> (m)"
4141
dtype: str = "floatX"
4242
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
4343
all_trees = List[List[List[Tree]]]
4444

4545
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
46-
return dist_params[0].shape[:1]
46+
idx = dist_params[0].ndim - 2
47+
return [dist_params[0].shape[idx]]
4748

4849
@classmethod
4950
def rng_fn( # pylint: disable=W0237
50-
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
51+
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
5152
):
5253
if not size:
5354
size = None
55+
5456
if not cls.all_trees:
5557
if size is not None:
5658
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -97,9 +99,6 @@ class BART(Distribution):
9799
List of SplitRule objects, one per column in input data.
98100
Allows using different split rules for different columns. Default is ContinuousSplitRule.
99101
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
100-
shape: : Optional[Tuple], default None
101-
Specify the output shape. If shape is different from (len(X)) (the default), train a
102-
separate tree for each value in other dimensions.
103102
separate_trees : Optional[bool], default False
104103
When training multiple trees (by setting a shape parameter), the default behavior is to
105104
learn a joint tree structure and only have different leaf values for each.

tests/test_bart.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_bart_vi(response):
5353
mu = pmb.BART("mu", X, Y, m=10, response=response)
5454
sigma = pm.HalfNormal("sigma", 1)
5555
y = pm.Normal("y", mu, sigma, observed=Y)
56-
idata = pm.sample(random_seed=3415)
56+
idata = pm.sample(tune=200, draws=200, random_seed=3415)
5757
var_imp = (
5858
idata.sample_stats["variable_inclusion"]
5959
.stack(samples=("chain", "draw"))
@@ -77,8 +77,8 @@ def test_missing_data(response):
7777
with pm.Model() as model:
7878
mu = pmb.BART("mu", X, Y, m=10, response=response)
7979
sigma = pm.HalfNormal("sigma", 1)
80-
y = pm.Normal("y", mu, sigma, observed=Y)
81-
idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
80+
pm.Normal("y", mu, sigma, observed=Y)
81+
pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
8282

8383

8484
@pytest.mark.parametrize(
@@ -116,7 +116,7 @@ def test_shape(response):
116116
with pm.Model() as model:
117117
w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
118118
y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
119-
idata = pm.sample(random_seed=3415)
119+
idata = pm.sample(tune=50, draws=10, random_seed=3415)
120120

121121
assert model.initial_point()["w"].shape == (2, 250)
122122
assert idata.posterior.coords["w_dim_0"].data.size == 2
@@ -133,7 +133,7 @@ class TestUtils:
133133
mu = pmb.BART("mu", X, Y, m=10)
134134
sigma = pm.HalfNormal("sigma", 1)
135135
y = pm.Normal("y", mu, sigma, observed=Y)
136-
idata = pm.sample(random_seed=3415)
136+
idata = pm.sample(tune=200, draws=200, random_seed=3415)
137137

138138
def test_sample_posterior(self):
139139
all_trees = self.mu.owner.op.all_trees

0 commit comments

Comments
 (0)