diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index eb869d2..decb499 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -26,6 +26,7 @@ from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.sharedvar import TensorSharedVariable +from pytensor.tensor.variable import TensorVariable from .split_rules import SplitRule from .tree import Tree @@ -54,7 +55,7 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if isinstance(cls.Y, TensorSharedVariable): + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): Y = cls.Y.eval() else: Y = cls.Y