15
15
# limitations under the License.
16
16
17
17
from multiprocessing import Manager
18
- import aesara .tensor as at
19
18
import numpy as np
20
-
21
- from aeppl .logprob import _logprob
22
- from aesara .tensor .random .op import RandomVariable
23
- from aesara .tensor .var import Variable
24
-
25
19
from pandas import DataFrame , Series
26
20
27
21
from pymc .distributions .distribution import Distribution , _moment
22
+ from pymc .logprob .abstract import _logprob
23
+ import pytensor .tensor as pt
24
+ from pytensor .tensor .random .op import RandomVariable
25
+
28
26
29
27
from .utils import _sample_posterior
30
28
@@ -42,11 +40,7 @@ class BARTRV(RandomVariable):
42
40
all_trees = None
43
41
44
42
def _supp_shape_from_params (self , dist_params , rep_param_idx = 1 , param_shapes = None ):
45
- if isinstance (self .X , Variable ):
46
- shape = self .X .shape [0 ].eval ()
47
- else :
48
- shape = self .X .shape [0 ]
49
- return (shape ,)
43
+ return dist_params [0 ].shape [:1 ]
50
44
51
45
@classmethod
52
46
def rng_fn (cls , rng = None , X = None , Y = None , m = None , alpha = None , split_prior = None , size = None ):
@@ -145,11 +139,11 @@ def logp(self, x, *inputs):
145
139
-------
146
140
TensorVariable
147
141
"""
148
- return at .zeros_like (x )
142
+ return pt .zeros_like (x )
149
143
150
144
@classmethod
151
145
def get_moment (cls , rv , size , * rv_inputs ):
152
- mean = at .fill (size , rv .Y .mean ())
146
+ mean = pt .fill (size , rv .Y .mean ())
153
147
return mean
154
148
155
149
0 commit comments