Skip to content

Commit 77116d1

Browse files
Patch for case when Y is a TensorVariable (#206)
* add case tensor var for Y * Improve `isinstance` statement Co-authored-by: Osvaldo A Martin <[email protected]> --------- Co-authored-by: Osvaldo A Martin <[email protected]>
1 parent d4e8cad commit 77116d1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc_bart/bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc.logprob.abstract import _logprob
2727
from pytensor.tensor.random.op import RandomVariable
2828
from pytensor.tensor.sharedvar import TensorSharedVariable
29+
from pytensor.tensor.variable import TensorVariable
2930

3031
from .split_rules import SplitRule
3132
from .tree import Tree
@@ -54,7 +55,7 @@ def rng_fn( # pylint: disable=W0237
5455
if not size:
5556
size = None
5657

57-
if isinstance(cls.Y, TensorSharedVariable):
58+
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
5859
Y = cls.Y.eval()
5960
else:
6061
Y = cls.Y

0 commit comments

Comments
 (0)