Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"plot_variable_importance",
"plot_variable_inclusion",
]
__version__ = "0.7.1"
__version__ = "0.8.0"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
2 changes: 1 addition & 1 deletion pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_moment(rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = bart_op
params = [X, Y, m, alpha, beta, split_prior]
params = [X, Y, m, alpha, beta]
return super().__new__(cls, name, *params, **kwargs)

@classmethod
Expand Down
12 changes: 8 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import numpy.typing as npt
from numba import njit
from pymc.initial_point import PointType
from pymc.model import Model, modelcontext
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
Expand Down Expand Up @@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915
num_particles: int = 10,
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None, # pylint: disable=unused-argument
):
model = modelcontext(model)
initial_values = model.initial_point()
if initial_point is None:
initial_point = model.initial_point()
if vars is None:
vars = model.value_vars
else:
Expand All @@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915
self.m = self.bart.m
self.response = self.bart.response

shape = initial_values[value_bart.name].shape
shape = initial_point[value_bart.name].shape

self.shape = 1 if len(shape) == 1 else shape[0]

Expand Down Expand Up @@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915

self.num_particles = num_particles
self.indices = list(range(1, num_particles))
shared = make_shared_replacements(initial_values, vars, model)
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
shared = make_shared_replacements(initial_point, vars, model)
self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
self.all_particles = [
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pymc>=5.16.2, <=5.18
pymc>=5.16.2, <=5.19.1
arviz>=0.18.0
numba
matplotlib
Expand Down
6 changes: 4 additions & 2 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ def test_categorical_model(separate_trees, split_rule):
separate_trees=separate_trees,
)
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
idata = pm.sample(random_seed=3415, tune=300, draws=300)
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
idata = pm.sample(tune=300, draws=300, random_seed=3415)
idata = pm.sample_posterior_predictive(
idata, predictions=True, extend_inferencedata=True, random_seed=3415
)

# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
Loading