File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 25
25
from pymc .distributions .distribution import Distribution , _support_point
26
26
from pymc .logprob .abstract import _logprob
27
27
from pytensor .tensor .random .op import RandomVariable
28
+ from pytensor .tensor .sharedvar import TensorSharedVariable
28
29
29
30
from .split_rules import SplitRule
30
31
from .tree import Tree
@@ -53,11 +54,16 @@ def rng_fn( # pylint: disable=W0237
53
54
if not size :
54
55
size = None
55
56
57
+ if isinstance (cls .Y , TensorSharedVariable ):
58
+ Y = cls .Y .eval ()
59
+ else :
60
+ Y = cls .Y
61
+
56
62
if not cls .all_trees :
57
63
if size is not None :
58
- return np .full ((size [0 ], cls . Y .shape [0 ]), cls . Y .mean ())
64
+ return np .full ((size [0 ], Y .shape [0 ]), Y .mean ())
59
65
else :
60
- return np .full (cls . Y .shape [0 ], cls . Y .mean ())
66
+ return np .full (Y .shape [0 ], Y .mean ())
61
67
else :
62
68
if size is not None :
63
69
shape = size [0 ]
You can’t perform that action at this time.
0 commit comments