Skip to content

Commit 3364ed3

Browse files
authored
Check if Y is a shared var in rng_fn (#202)
1 parent 07f55d4 commit 3364ed3

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pymc_bart/bart.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pymc.distributions.distribution import Distribution, _support_point
2626
from pymc.logprob.abstract import _logprob
2727
from pytensor.tensor.random.op import RandomVariable
28+
from pytensor.tensor.sharedvar import TensorSharedVariable
2829

2930
from .split_rules import SplitRule
3031
from .tree import Tree
@@ -53,11 +54,16 @@ def rng_fn( # pylint: disable=W0237
5354
if not size:
5455
size = None
5556

57+
if isinstance(cls.Y, TensorSharedVariable):
58+
Y = cls.Y.eval()
59+
else:
60+
Y = cls.Y
61+
5662
if not cls.all_trees:
5763
if size is not None:
58-
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
64+
return np.full((size[0], Y.shape[0]), Y.mean())
5965
else:
60-
return np.full(cls.Y.shape[0], cls.Y.mean())
66+
return np.full(Y.shape[0], Y.mean())
6167
else:
6268
if size is not None:
6369
shape = size[0]

0 commit comments

Comments
 (0)