From 2047f90d4d97eaedd8e4e07b4c35d143cc1a67de Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 7 Nov 2024 13:21:33 -0300 Subject: [PATCH 1/3] conform to recent changes in pymc --- pymc_bart/bart.py | 5 +++-- pymc_bart/pgbart.py | 7 +++++-- requirements.txt | 2 +- tests/test_bart.py | 8 ++++---- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 969baf4..1f21280 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -37,8 +37,7 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - ndim_supp = 1 - ndims_params: List[int] = [2, 1, 0, 0, 0, 1] + signature = "(n,d),(n),(),(),(),(n)->(n)" dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = List[List[List[Tree]]] @@ -50,6 +49,8 @@ def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=Non def rng_fn( # pylint: disable=W0237 cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None ): + if not size: + size = None if not cls.all_trees: if size is not None: return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 91a9beb..6de7a53 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -114,7 +114,10 @@ class PGBART(ArrayStepShared): name = "pgbart" default_blocked = False generates_stats = True - stats_dtypes = [{"variable_inclusion": object, "tune": bool}] + stats_dtypes_shapes: dict[str, tuple[type, list]] = { + "variable_inclusion": (object, []), + "tune": (bool, []), + } def __init__( # noqa: PLR0915 self, @@ -227,7 +230,7 @@ def __init__( # noqa: PLR0915 def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") - upper = min(self.lower + self.batch[~self.tune], self.m) + upper = min(self.lower + self.batch[not self.tune], self.m) tree_ids = range(self.lower, upper) self.lower = upper if upper < self.m else 0 diff --git a/requirements.txt b/requirements.txt index e741cef..ac9bd07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc<=5.16.2 +pymc>=5.16.2, <=5.18 arviz>=0.18.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index dfbd86f..2ecd52c 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -3,7 +3,7 @@ import pytest from numpy.testing import assert_almost_equal, assert_array_equal from pymc.initial_point import make_initial_point_fn -from pymc.logprob.basic import joint_logp +from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb @@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, return_transformed=False, - default_strategy="moment", + default_strategy="support_point", ) moment = fn(0)["x"] expected = np.asarray(expected) @@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): if check_finite_logp: logp_moment = ( - joint_logp( + transformed_conditional_logp( (model["x"],), rvs_to_values={model["x"]: pm.math.constant(moment)}, rvs_to_transforms={}, @@ -91,7 +91,7 @@ def test_shared_variable(response): Y = np.random.normal(0, 1, size=50) with pm.Model() as model: - data_X = pm.MutableData("data_X", X) + data_X = pm.Data("data_X", X) mu = pmb.BART("mu", data_X, Y, m=2, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape) From ed7c4d0b3415ebb33261b236e4ef07530211034c Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 7 Nov 2024 13:22:35 -0300 Subject: [PATCH 2/3] update version --- pymc_bart/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index c10b8f8..8774803 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -36,7 +36,7 @@ "plot_pdp", "plot_variable_importance", ] -__version__ = "0.7.0" +__version__ = "0.7.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] From 99209df878ddc727d493367bb2d45b2ff3c17d11 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 7 Nov 2024 16:42:35 -0300 Subject: [PATCH 3/3] fix shapes --- pymc_bart/bart.py | 11 +++++------ tests/test_bart.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 1f21280..a21bda5 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -37,20 +37,22 @@ class BARTRV(RandomVariable): """Base class for BART.""" name: str = "BART" - signature = "(n,d),(n),(),(),(),(n)->(n)" + signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = List[List[List[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed - return dist_params[0].shape[:1] + idx = dist_params[0].ndim - 2 + return [dist_params[0].shape[idx]] @classmethod def rng_fn( # pylint: disable=W0237 - cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None + cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None ): if not size: size = None + if not cls.all_trees: if size is not None: return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) @@ -97,9 +99,6 @@ class BART(Distribution): List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. - shape: : Optional[Tuple], default None - Specify the output shape. If shape is different from (len(X)) (the default), train a - separate tree for each value in other dimensions. separate_trees : Optional[bool], default False When training multiple trees (by setting a shape parameter), the default behavior is to learn a joint tree structure and only have different leaf values for each. diff --git a/tests/test_bart.py b/tests/test_bart.py index 2ecd52c..e56735e 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -53,7 +53,7 @@ def test_bart_vi(response): mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) var_imp = ( idata.sample_stats["variable_inclusion"] .stack(samples=("chain", "draw")) @@ -77,8 +77,8 @@ def test_missing_data(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415) + pm.Normal("y", mu, sigma, observed=Y) + pm.sample(tune=100, draws=100, chains=1, random_seed=3415) @pytest.mark.parametrize( @@ -116,7 +116,7 @@ def test_shape(response): with pm.Model() as model: w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250)) y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=50, draws=10, random_seed=3415) assert model.initial_point()["w"].shape == (2, 250) assert idata.posterior.coords["w_dim_0"].data.size == 2 @@ -133,7 +133,7 @@ class TestUtils: mu = pmb.BART("mu", X, Y, m=10) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415) + idata = pm.sample(tune=200, draws=200, random_seed=3415) def test_sample_posterior(self): all_trees = self.mu.owner.op.all_trees