Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pymc.logprob.abstract import _logprob
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.variable import TensorVariable

from .split_rules import SplitRule
from .tree import Tree
Expand Down Expand Up @@ -54,7 +55,7 @@ def rng_fn( # pylint: disable=W0237
if not size:
size = None

if isinstance(cls.Y, TensorSharedVariable):
if isinstance(cls.Y, TensorSharedVariable) or isinstance(cls.Y, TensorVariable):
Y = cls.Y.eval()
else:
Y = cls.Y
Expand Down
Loading