diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 16a856c..94b91c3 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -25,6 +25,7 @@ from pymc.distributions.distribution import Distribution, _support_point from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.sharedvar import TensorSharedVariable from .split_rules import SplitRule from .tree import Tree @@ -53,11 +54,16 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None + if isinstance(cls.Y, TensorSharedVariable): + Y = cls.Y.eval() + else: + Y = cls.Y + if not cls.all_trees: if size is not None: - return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) + return np.full((size[0], Y.shape[0]), Y.mean()) else: - return np.full(cls.Y.shape[0], cls.Y.mean()) + return np.full(Y.shape[0], Y.mean()) else: if size is not None: shape = size[0]