2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
@@ -36,7 +36,7 @@
"plot_pdp",
"plot_variable_importance",
]
__version__ = "0.7.0"
__version__ = "0.7.1"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
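Because PGBART is appended to pm.STEP_METHODS, importing pymc_bart is enough for pm.sample to assign the sampler to BART variables automatically. A minimal sketch (the data and model are hypothetical, mirroring the tests further down):

import numpy as np
import pymc as pm
import pymc_bart as pmb

X = np.random.normal(size=(50, 3))
Y = np.random.normal(size=50)

with pm.Model():
    mu = pmb.BART("mu", X, Y, m=10)
    pm.Normal("y", mu, 1.0, observed=Y)
    # No step argument needed: PGBART is picked up from pm.STEP_METHODS.
    idata = pm.sample(tune=100, draws=100)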
14 changes: 7 additions & 7 deletions pymc_bart/bart.py
@@ -37,19 +37,22 @@ class BARTRV(RandomVariable):
"""Base class for BART."""

name: str = "BART"
ndim_supp = 1
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[List[Tree]]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
return dist_params[0].shape[:1]
idx = dist_params[0].ndim - 2
return [dist_params[0].shape[idx]]

@classmethod
def rng_fn( # pylint: disable=W0237
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
):
if not size:
size = None

if not cls.all_trees:
if size is not None:
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -96,9 +99,6 @@ class BART(Distribution):
         List of SplitRule objects, one per column in input data.
         Allows using different split rules for different columns. Default is ContinuousSplitRule.
         Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
-    shape : Optional[Tuple], default None
-        Specify the output shape. If shape is different from (len(X)) (the default), train a
-        separate tree for each value in other dimensions.
     separate_trees : Optional[bool], default False
         When training multiple trees (by setting a shape parameter), the default behavior is to
         learn a joint tree structure and only have different leaf values for each.
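The new gufunc-style signature declares five inputs (X, Y, m, alpha, beta) and one output of length m, replacing the older ndim_supp/ndims_params attributes, and the ndim - 2 indexing makes the support shape come from X's row dimension even when a batch dimension is present. A minimal sketch of that shape logic in plain NumPy (purely illustrative, not library code):

import numpy as np

# Only the shapes matter here: (m, n) = observations x covariates.
X = np.zeros((250, 3))             # unbatched design matrix
X_batched = np.zeros((2, 250, 3))  # with a leading batch dimension

for x in (X, X_batched):
    idx = x.ndim - 2               # second-to-last axis = observations
    print([x.shape[idx]])          # -> [250] in both cases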
7 changes: 5 additions & 2 deletions pymc_bart/pgbart.py
@@ -114,7 +114,10 @@ class PGBART(ArrayStepShared):
name = "pgbart"
default_blocked = False
generates_stats = True
stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
"variable_inclusion": (object, []),
"tune": (bool, []),
}

def __init__( # noqa: PLR0915
self,
@@ -227,7 +230,7 @@ def __init__(  # noqa: PLR0915
     def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        upper = min(self.lower + self.batch[~self.tune], self.m)
+        upper = min(self.lower + self.batch[not self.tune], self.m)
         tree_ids = range(self.lower, upper)
         self.lower = upper if upper < self.m else 0
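The indexing fix matters because self.batch is a two-element tuple (batch size while tuning, batch size afterwards) and ~ on a Python bool is bitwise complement, not logical negation. A small illustration (the batch values here are made up):

batch = (10, 25)  # hypothetical (tuning, post-tuning) batch sizes
tune = True

print(not tune)         # False, which indexes as 0
print(~tune)            # -2: bitwise complement of True
print(batch[not tune])  # 10 -> explicitly batch[0] while tuning
print(batch[~tune])     # also 10, but only because len(batch) == 2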
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-pymc<=5.16.2
+pymc>=5.16.2, <=5.18
 arviz>=0.18.0
 numba
 matplotlib
18 changes: 9 additions & 9 deletions tests/test_bart.py
@@ -3,7 +3,7 @@
 import pytest
 from numpy.testing import assert_almost_equal, assert_array_equal
 from pymc.initial_point import make_initial_point_fn
-from pymc.logprob.basic import joint_logp
+from pymc.logprob.basic import transformed_conditional_logp

 import pymc_bart as pmb

@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
     fn = make_initial_point_fn(
         model=model,
         return_transformed=False,
-        default_strategy="moment",
+        default_strategy="support_point",
     )
     moment = fn(0)["x"]
     expected = np.asarray(expected)
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):

     if check_finite_logp:
         logp_moment = (
-            joint_logp(
+            transformed_conditional_logp(
                 (model["x"],),
                 rvs_to_values={model["x"]: pm.math.constant(moment)},
                 rvs_to_transforms={},
@@ -53,7 +53,7 @@ def test_bart_vi(response):
         mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=200, draws=200, random_seed=3415)
         var_imp = (
             idata.sample_stats["variable_inclusion"]
             .stack(samples=("chain", "draw"))
@@ -77,8 +77,8 @@ def test_missing_data(response):
     with pm.Model() as model:
         mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
+        pm.Normal("y", mu, sigma, observed=Y)
+        pm.sample(tune=100, draws=100, chains=1, random_seed=3415)


 @pytest.mark.parametrize(
@@ -91,7 +91,7 @@ def test_shared_variable(response):
     Y = np.random.normal(0, 1, size=50)

     with pm.Model() as model:
-        data_X = pm.MutableData("data_X", X)
+        data_X = pm.Data("data_X", X)
         mu = pmb.BART("mu", data_X, Y, m=2, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
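pm.MutableData was folded into pm.Data in recent PyMC releases, and containers created with pm.Data remain swappable after sampling. A hedged sketch of what that enables in a test like this one (X_new and idata are hypothetical stand-ins for a replacement design matrix and a prior pm.sample result):

with model:
    pm.set_data({"data_X": X_new})  # swap in new covariates
    preds = pm.sample_posterior_predictive(idata)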
@@ -116,7 +116,7 @@ def test_shape(response):
     with pm.Model() as model:
         w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
         y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=50, draws=10, random_seed=3415)

     assert model.initial_point()["w"].shape == (2, 250)
     assert idata.posterior.coords["w_dim_0"].data.size == 2
@@ -133,7 +133,7 @@ class TestUtils:
         mu = pmb.BART("mu", X, Y, m=10)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
+        idata = pm.sample(tune=200, draws=200, random_seed=3415)

     def test_sample_posterior(self):
         all_trees = self.mu.owner.op.all_trees