From 3641a52f1981f0ce917e052f7253876f72f4e3d5 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 4 Dec 2024 14:10:13 -0500 Subject: [PATCH] Check if Y is a shared var in rng_fn --- pymc_bart/bart.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 16a856c..94b91c3 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -25,6 +25,7 @@ from pymc.distributions.distribution import Distribution, _support_point from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.sharedvar import TensorSharedVariable from .split_rules import SplitRule from .tree import Tree @@ -53,11 +54,16 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None + if isinstance(cls.Y, TensorSharedVariable): + Y = cls.Y.eval() + else: + Y = cls.Y + if not cls.all_trees: if size is not None: - return np.full((size[0], cls.Y.shape[0]), cls.Y.mean()) + return np.full((size[0], Y.shape[0]), Y.mean()) else: - return np.full(cls.Y.shape[0], cls.Y.mean()) + return np.full(Y.shape[0], Y.mean()) else: if size is not None: shape = size[0]