Skip to content

Commit 2047f90

Browse files
committed
conform to recent changes in pymc
1 parent 1741d7d commit 2047f90

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

pymc_bart/bart.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ class BARTRV(RandomVariable):
3737
"""Base class for BART."""
3838

3939
name: str = "BART"
40-
ndim_supp = 1
41-
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
40+
signature = "(n,d),(n),(),(),(),(n)->(n)"
4241
dtype: str = "floatX"
4342
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
4443
all_trees = List[List[List[Tree]]]
@@ -50,6 +49,8 @@ def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=Non
5049
def rng_fn( # pylint: disable=W0237
5150
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
5251
):
52+
if not size:
53+
size = None
5354
if not cls.all_trees:
5455
if size is not None:
5556
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())

pymc_bart/pgbart.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ class PGBART(ArrayStepShared):
114114
name = "pgbart"
115115
default_blocked = False
116116
generates_stats = True
117-
stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
117+
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
118+
"variable_inclusion": (object, []),
119+
"tune": (bool, []),
120+
}
118121

119122
def __init__( # noqa: PLR0915
120123
self,
@@ -227,7 +230,7 @@ def __init__( # noqa: PLR0915
227230
def astep(self, _):
228231
variable_inclusion = np.zeros(self.num_variates, dtype="int")
229232

230-
upper = min(self.lower + self.batch[~self.tune], self.m)
233+
upper = min(self.lower + self.batch[not self.tune], self.m)
231234
tree_ids = range(self.lower, upper)
232235
self.lower = upper if upper < self.m else 0
233236

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pymc<=5.16.2
1+
pymc>=5.16.2, <=5.18
22
arviz>=0.18.0
33
numba
44
matplotlib

tests/test_bart.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from numpy.testing import assert_almost_equal, assert_array_equal
55
from pymc.initial_point import make_initial_point_fn
6-
from pymc.logprob.basic import joint_logp
6+
from pymc.logprob.basic import transformed_conditional_logp
77

88
import pymc_bart as pmb
99

@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
1212
fn = make_initial_point_fn(
1313
model=model,
1414
return_transformed=False,
15-
default_strategy="moment",
15+
default_strategy="support_point",
1616
)
1717
moment = fn(0)["x"]
1818
expected = np.asarray(expected)
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
2727

2828
if check_finite_logp:
2929
logp_moment = (
30-
joint_logp(
30+
transformed_conditional_logp(
3131
(model["x"],),
3232
rvs_to_values={model["x"]: pm.math.constant(moment)},
3333
rvs_to_transforms={},
@@ -91,7 +91,7 @@ def test_shared_variable(response):
9191
Y = np.random.normal(0, 1, size=50)
9292

9393
with pm.Model() as model:
94-
data_X = pm.MutableData("data_X", X)
94+
data_X = pm.Data("data_X", X)
9595
mu = pmb.BART("mu", data_X, Y, m=2, response=response)
9696
sigma = pm.HalfNormal("sigma", 1)
9797
y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)

0 commit comments

Comments
 (0)