Skip to content

Commit 833790f

Browse files
authored
update to pymc 5 and pytensor (#29)
* update to pymc 5 and pytensor * fix pylint * black * fix shape
1 parent 9aa4896 commit 833790f

File tree

4 files changed

+15
-21
lines changed

4 files changed

+15
-21
lines changed

pymc_bart/bart.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,14 @@
1515
# limitations under the License.
1616

1717
from multiprocessing import Manager
18-
import aesara.tensor as at
1918
import numpy as np
20-
21-
from aeppl.logprob import _logprob
22-
from aesara.tensor.random.op import RandomVariable
23-
from aesara.tensor.var import Variable
24-
2519
from pandas import DataFrame, Series
2620

2721
from pymc.distributions.distribution import Distribution, _moment
22+
from pymc.logprob.abstract import _logprob
23+
import pytensor.tensor as pt
24+
from pytensor.tensor.random.op import RandomVariable
25+
2826

2927
from .utils import _sample_posterior
3028

@@ -42,11 +40,7 @@ class BARTRV(RandomVariable):
4240
all_trees = None
4341

4442
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
45-
if isinstance(self.X, Variable):
46-
shape = self.X.shape[0].eval()
47-
else:
48-
shape = self.X.shape[0]
49-
return (shape,)
43+
return dist_params[0].shape[:1]
5044

5145
@classmethod
5246
def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, size=None):
@@ -145,11 +139,11 @@ def logp(self, x, *inputs):
145139
-------
146140
TensorVariable
147141
"""
148-
return at.zeros_like(x)
142+
return pt.zeros_like(x)
149143

150144
@classmethod
151145
def get_moment(cls, rv, size, *rv_inputs):
152-
mean = at.fill(size, rv.Y.mean())
146+
mean = pt.fill(size, rv.Y.mean())
153147
return mean
154148

155149

pymc_bart/pgbart.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
import numpy as np
2121

22-
from aesara import function as aesara_function
23-
from aesara import config
24-
from aesara.tensor.var import Variable
22+
from pytensor import function as pytensor_function
23+
from pytensor import config
24+
from pytensor.tensor.var import Variable
2525

2626
from pymc.model import modelcontext
2727
from pymc.step_methods.arraystep import ArrayStepShared, Competence
28-
from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
28+
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
2929

3030

3131
from pymc_bart.bart import BARTRV
@@ -656,9 +656,9 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
656656
vars: List
657657
containing :class:`pymc.Distribution` for the input variables
658658
shared: List
659-
containing :class:`aesara.tensor.Tensor` for depended shared data
659+
containing :class:`pytensor.tensor.Tensor` for depended shared data
660660
"""
661661
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
662-
function = aesara_function([inarray0], out_list[0])
662+
function = pytensor_function([inarray0], out_list[0])
663663
function.trust_input = True
664664
return function

pymc_bart/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from copy import deepcopy
1818

19-
from aesara import config
19+
from pytensor import config
2020
import numpy as np
2121

2222

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import matplotlib.pyplot as plt
55
import numpy as np
66

7-
from aesara.tensor.var import Variable
7+
from pytensor.tensor.var import Variable
88
from scipy.interpolate import griddata
99
from scipy.signal import savgol_filter
1010
from scipy.stats import pearsonr

0 commit comments

Comments
 (0)