
Commit 61a63b2 (parent e96880f)

Update PyMC dependency

File tree: 12 files changed (+114, -124 lines)
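The change is mechanical throughout: PyMC 5 replaced its Aesara compute backend with the PyTensor fork, so every aesara / at. / pm.aesaraf reference becomes pytensor / pt. / pm.pytensorf. A minimal before/after sketch of the pattern applied in the files below (illustrative, not taken from the diff):

    # Before (PyMC 4.x, Aesara backend):
    # import aesara.tensor as at
    # from pymc.aesaraf import floatX
    # x = at.as_tensor_variable(floatX(0.5))

    # After (PyMC 5.x, PyTensor backend):
    import pytensor.tensor as pt
    from pymc.pytensorf import floatX

    x = pt.as_tensor_variable(floatX(0.5))  # same API, new namespace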

.github/workflows/test.yml
Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      AESARA_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
+      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
     defaults:
       run:
         shell: bash -l {0}
@@ -91,7 +91,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      AESARA_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=core2'
+      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=core2'
     defaults:
       run:
         shell: cmd
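The renamed variable behaves like its predecessor: PyTensor parses PYTENSOR_FLAGS from the environment when it is first imported. A quick local check (a hedged sketch, not part of the commit):

    import os

    # Flags must be set before pytensor is imported;
    # the parsed values then surface on pytensor.config.
    os.environ.setdefault("PYTENSOR_FLAGS", "floatX=float32")

    import pytensor

    print(pytensor.config.floatX)  # "float32" if the flag took effect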

pymc_experimental/distributions/continuous.py
Lines changed: 22 additions & 22 deletions

@@ -21,14 +21,14 @@
 
 from typing import List, Tuple, Union
 
-import aesara.tensor as at
 import numpy as np
-from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.var import TensorVariable
-from pymc.aesaraf import floatX
+import pytensor.tensor as pt
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Continuous
 from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.pytensorf import floatX
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.var import TensorVariable
 from scipy import stats
 
 
@@ -144,9 +144,9 @@ def dist(cls, mu=0, sigma=1, xi=0, scipy=False, **kwargs):
         # If SciPy, use its parametrization, otherwise convert to standard
         if scipy:
             xi = -xi
-        mu = at.as_tensor_variable(floatX(mu))
-        sigma = at.as_tensor_variable(floatX(sigma))
-        xi = at.as_tensor_variable(floatX(xi))
+        mu = pt.as_tensor_variable(floatX(mu))
+        sigma = pt.as_tensor_variable(floatX(sigma))
+        xi = pt.as_tensor_variable(floatX(xi))
 
         return super().dist([mu, sigma, xi], **kwargs)
 
@@ -159,26 +159,26 @@ def logp(value, mu, sigma, xi):
         ----------
         value: numeric
             Value(s) for which log-probability is calculated. If the log probabilities for multiple
-            values are desired the values must be provided in a numpy array or Aesara tensor
+            values are desired the values must be provided in a numpy array or Pytensor tensor
 
         Returns
         -------
         TensorVariable
         """
         scaled = (value - mu) / sigma
 
-        logp_expression = at.switch(
-            at.isclose(xi, 0),
-            -at.log(sigma) - scaled - at.exp(-scaled),
-            -at.log(sigma)
-            - ((xi + 1) / xi) * at.log1p(xi * scaled)
-            - at.pow(1 + xi * scaled, -1 / xi),
+        logp_expression = pt.switch(
+            pt.isclose(xi, 0),
+            -pt.log(sigma) - scaled - pt.exp(-scaled),
+            -pt.log(sigma)
+            - ((xi + 1) / xi) * pt.log1p(xi * scaled)
+            - pt.pow(1 + xi * scaled, -1 / xi),
         )
 
-        logp = at.switch(at.gt(1 + xi * scaled, 0.0), logp_expression, -np.inf)
+        logp = pt.switch(pt.gt(1 + xi * scaled, 0.0), logp_expression, -np.inf)
 
         return check_parameters(
-            logp, sigma > 0, at.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
+            logp, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
         )
 
     def logcdf(value, mu, sigma, xi):
@@ -198,21 +198,21 @@ def logcdf(value, mu, sigma, xi):
         TensorVariable
         """
         scaled = (value - mu) / sigma
-        logc_expression = at.switch(
-            at.isclose(xi, 0), -at.exp(-scaled), -at.pow(1 + xi * scaled, -1 / xi)
+        logc_expression = pt.switch(
+            pt.isclose(xi, 0), -pt.exp(-scaled), -pt.pow(1 + xi * scaled, -1 / xi)
         )
 
-        logc = at.switch(1 + xi * (value - mu) / sigma > 0, logc_expression, -np.inf)
+        logc = pt.switch(1 + xi * (value - mu) / sigma > 0, logc_expression, -np.inf)
 
         return check_parameters(
-            logc, sigma > 0, at.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
+            logc, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
         )
 
     def moment(rv, size, mu, sigma, xi):
         r"""
         Using the mode, as the mean can be infinite when :math:`\xi > 1`
         """
-        mode = at.switch(at.isclose(xi, 0), mu, mu + sigma * (at.pow(1 + xi, -xi) - 1) / xi)
+        mode = pt.switch(pt.isclose(xi, 0), mu, mu + sigma * (pt.pow(1 + xi, -xi) - 1) / xi)
         if not rv_size_is_none(size):
-            mode = at.full(size, mode)
+            mode = pt.full(size, mode)
         return mode
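For reference, the expressions renamed in logp above implement the standard generalized extreme value log-density. In LaTeX, with s = (x - mu) / sigma (my transcription of the code, not text from the commit):

    \log f(x \mid \mu, \sigma, \xi) =
    \begin{cases}
      -\log\sigma - s - e^{-s}, & \xi = 0, \\
      -\log\sigma - \left(1 + \tfrac{1}{\xi}\right)\log(1 + \xi s)
        - (1 + \xi s)^{-1/\xi}, & \xi \neq 0,\ 1 + \xi s > 0,
    \end{cases}
    \qquad s = \frac{x - \mu}{\sigma},

with the log-probability set to -inf outside the support, exactly as the pt.switch on 1 + xi * scaled > 0 does; note that ((xi + 1) / xi) in the code equals 1 + 1/xi here.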

pymc_experimental/distributions/histogram_utils.py
Lines changed: 2 additions & 2 deletions

@@ -99,15 +99,15 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ----------
     name : str
         Name for the Potential
-    dist : aesara.tensor.var.TensorVariable
+    dist : pytensor.tensor.var.TensorVariable
         The output of pm.Distribution.dist()
     observed : ArrayLike
         Observed value to construct a histogram. Histogram is computed over 0th axis.
         Dask is supported.
 
     Returns
     -------
-    aesara.tensor.var.TensorVariable
+    pytensor.tensor.var.TensorVariable
         Potential
 
     Examples
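As a usage sketch only (assembled from the docstring above; the exact import path and the toy data are my assumptions, not from the commit):

    import numpy as np
    import pymc as pm
    from pymc_experimental.distributions.histogram_utils import histogram_approximation

    data = np.random.default_rng(0).normal(loc=1.0, size=10_000)

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 10.0)
        # Potential built from a histogram of `data` rather than the full likelihood
        pot = histogram_approximation("lik_hist", pm.Normal.dist(mu, 1.0), observed=data)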

pymc_experimental/gp/latent_approx.py
Lines changed: 22 additions & 22 deletions

@@ -13,9 +13,9 @@
 # limitations under the License.
 
 
-import aesara.tensor as at
 import numpy as np
 import pymc as pm
+import pytensor.tensor as pt
 from pymc.gp.util import JITTER_DEFAULT, cholesky, solve_lower, solve_upper, stabilize
 
 
@@ -42,7 +42,7 @@ def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs):
         u = pm.Deterministic(name + "_u", L @ v)
 
         Kfu = self.cov_func(X, Xu)
-        Kuuiu = solve_upper(at.transpose(L), solve_lower(L, u))
+        Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))
 
         return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L
 
@@ -62,8 +62,8 @@ def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs):
     def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
         Ksu = self.cov_func(Xnew, Xu)
         mu = self.mean_func(Xnew) + Ksu @ Kuuiu
-        tmp = solve_lower(L, at.transpose(Ksu))
-        Qss = at.transpose(tmp) @ tmp  # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
+        tmp = solve_lower(L, pt.transpose(Ksu))
+        Qss = pt.transpose(tmp) @ tmp  # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
         Kss = self.cov_func(Xnew)
         Lss = cholesky(stabilize(Kss - Qss, jitter))
         return mu, Lss
@@ -100,42 +100,42 @@ def prior(self, name, X, **kwargs):
         return f
 
     def _generate_basis(self, X, L):
-        indices = at.arange(1, self.M + 1)
-        m1 = (np.pi / (2.0 * L)) * at.tile(L + X, self.M)
-        m2 = at.diag(indices)
-        Phi = at.sin(m1 @ m2) / at.sqrt(L)
+        indices = pt.arange(1, self.M + 1)
+        m1 = (np.pi / (2.0 * L)) * pt.tile(L + X, self.M)
+        m2 = pt.diag(indices)
+        Phi = pt.sin(m1 @ m2) / pt.sqrt(L)
         omega = (np.pi * indices) / (2.0 * L)
         return Phi, omega
 
     def _build_prior(self, name, X, **kwargs):
         n_obs = np.shape(X)[0]
 
         # standardize input scale
-        X = at.as_tensor_variable(X)
-        Xmu = at.mean(X, axis=0)
-        Xsd = at.std(X, axis=0)
+        X = pt.as_tensor_variable(X)
+        Xmu = pt.mean(X, axis=0)
+        Xsd = pt.std(X, axis=0)
         Xz = (X - Xmu) / Xsd
 
         # define L using Xz and c
-        La = at.abs(at.min(Xz))  # .eval()?
-        Lb = at.max(Xz)
-        L = self.c * at.max([La, Lb])
+        La = pt.abs(pt.min(Xz))  # .eval()?
+        Lb = pt.max(Xz)
+        L = self.c * pt.max([La, Lb])
 
         # make basis and omega, spectral density
         Phi, omega = self._generate_basis(Xz, L)
         scale, ls, spectral_density = self._validate_cov_func(self.cov_func)
         spd = scale * spectral_density(omega, ls / Xsd).flatten()
 
         beta = pm.Normal(f"{name}_coeffs_", size=self.M)
-        f = pm.Deterministic(name, self.mean_func(X) + at.dot(Phi * at.sqrt(spd), beta))
+        f = pm.Deterministic(name, self.mean_func(X) + pt.dot(Phi * pt.sqrt(spd), beta))
         return f, Phi, L, spd, beta, Xmu, Xsd
 
     def _build_conditional(self, Xnew, Xmu, Xsd, L, beta):
         Xnewz = (Xnew - Xmu) / Xsd
         Phi, omega = self._generate_basis(Xnewz, L)
         scale, ls, spectral_density = self._validate_cov_func(self.cov_func)
         spd = scale * spectral_density(omega, ls / Xsd).flatten()
-        return self.mean_func(Xnew) + at.dot(Phi * at.sqrt(spd), beta)
+        return self.mean_func(Xnew) + pt.dot(Phi * pt.sqrt(spd), beta)
 
     def conditional(self, name, Xnew):
         # warn about extrapolation
@@ -147,15 +147,15 @@ class ExpQuad(pm.gp.cov.ExpQuad):
     @staticmethod
     def spectral_density(omega, ls):
         # univariate spectral density, implement multi
-        return at.sqrt(2 * np.pi) * ls * at.exp(-0.5 * ls**2 * omega**2)
+        return pt.sqrt(2 * np.pi) * ls * pt.exp(-0.5 * ls**2 * omega**2)
 
 
 class Matern52(pm.gp.cov.Matern52):
     @staticmethod
     def spectral_density(omega, ls):
         # univariate spectral density, implement multi
         # https://arxiv.org/pdf/1611.06740.pdf
-        lam = at.sqrt(5) * (1.0 / ls)
+        lam = pt.sqrt(5) * (1.0 / ls)
         return (16.0 / 3.0) * lam**5 * (1.0 / (lam**2 + omega**2) ** 3)
 
 
@@ -165,7 +165,7 @@ def spectral_density(omega, ls):
         # univariate spectral density, implement multi
         # https://arxiv.org/pdf/1611.06740.pdf
         lam = np.sqrt(3.0) * (1.0 / ls)
-        return 4.0 * lam**3 * (1.0 / at.square(lam**2 + omega**2))
+        return 4.0 * lam**3 * (1.0 / pt.square(lam**2 + omega**2))
 
 
 class Matern12(pm.gp.cov.Matern12):
@@ -193,7 +193,7 @@ def __init__(
     def _build_prior(self, name, X, jitter=1e-6, **kwargs):
         mu = self.mean_func(X)
         Kxx = pm.gp.util.stabilize(self.cov_func(X), jitter)
-        vals, vecs = at.linalg.eigh(Kxx)
+        vals, vecs = pt.linalg.eigh(Kxx)
         ## NOTE: REMOVED PRECISION CUTOFF
         if self.variance_limit is None:
             n_eigs = self.n_eigs
@@ -204,7 +204,7 @@ def _build_prior(self, name, X, jitter=1e-6, **kwargs):
             n_eigs = ((vals[::-1].cumsum() / vals.sum()) > self.variance_limit).nonzero()[0][0]
         U = vecs[:, -n_eigs:]
         s = vals[-n_eigs:]
-        basis = U * at.sqrt(s)
+        basis = U * pt.sqrt(s)
 
         coefs_raw = pm.Normal(f"_gp_{name}_coefs", mu=0, sigma=1, size=n_eigs)
         # weight = pm.HalfNormal(f"_gp_{name}_sd")
@@ -222,7 +222,7 @@ def prior(self, name, X, jitter=1e-6, **kwargs):
     def _build_conditional(self, Xnew, X, f, U, s, jitter):
         Kxs = self.cov_func(X, Xnew)
         Kss = self.cov_func(Xnew)
-        Kxxpinv = U @ at.diag(1.0 / s) @ U.T
+        Kxxpinv = U @ pt.diag(1.0 / s) @ U.T
         mus = Kxs.T @ Kxxpinv @ f
         K = Kss - Kxs.T @ Kxxpinv @ Kxs
         L = pm.gp.util.cholesky(pm.gp.util.stabilize(K, jitter))
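_generate_basis above builds the Hilbert-space GP basis of Solin and Sarkka (2020): phi_j(x) = sin(pi j (x + L) / (2 L)) / sqrt(L) with frequencies omega_j = pi j / (2 L). A standalone NumPy sketch of the same computation (hsgp_basis is a hypothetical helper name, not from the diff):

    import numpy as np

    def hsgp_basis(x, M, L):
        """Hilbert-space GP basis: Phi has shape (n, M), omega has shape (M,)."""
        j = np.arange(1, M + 1)                     # basis indices 1..M
        omega = np.pi * j / (2.0 * L)               # spectral frequencies
        Phi = np.sin(np.outer(x + L, omega)) / np.sqrt(L)
        return Phi, omega

    x = np.linspace(-1.0, 1.0, 50)
    Phi, omega = hsgp_basis(x, M=10, L=1.2)         # L must bound the (scaled) inputs
    assert Phi.shape == (50, 10) and omega.shape == (10,)

The matrix product m1 @ m2 in the diff is just this outer product of (x + L) with the frequencies, written with a diagonal matrix.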

pymc_experimental/inference/pathfinder.py
Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ def fit_pathfinder(
     init_position_dict = model.initial_point()
     init_position = [init_position_dict[rv] for rv in rvs]
 
-    new_logprob, new_input = pm.aesaraf.join_nonshared_inputs(
+    new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         init_position_dict, (model.logp(),), model.value_vars, ()
     )

pymc_experimental/marginal_model.py
Lines changed: 18 additions & 19 deletions

@@ -1,22 +1,21 @@
 import warnings
 from typing import Sequence, Tuple, Union
 
-import aesara.tensor as at
 import numpy as np
-from aeppl import factorized_joint_logprob
-from aeppl.abstract import _get_measurable_outputs
-from aeppl.logprob import _logprob
-from aesara import Mode
-from aesara.compile import SharedVariable
-from aesara.compile.builders import OpFromGraph
-from aesara.graph import Constant, FunctionGraph, ancestors, clone_replace
-from aesara.scan import map as scan_map
-from aesara.tensor import TensorVariable
-from aesara.tensor.elemwise import Elemwise
+import pytensor.tensor as pt
 from pymc import SymbolicRandomVariable
-from pymc.aesaraf import constant_fold, inputvars
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
+from pymc.logprob.abstract import _get_measurable_outputs, _logprob
+from pymc.logprob.joint_logprob import factorized_joint_logprob
 from pymc.model import Model
+from pymc.pytensorf import constant_fold, inputvars
+from pytensor import Mode
+from pytensor.compile import SharedVariable
+from pytensor.compile.builders import OpFromGraph
+from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.scan import map as scan_map
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.elemwise import Elemwise
 
 __all__ = ["MarginalModel"]
 
@@ -381,7 +380,7 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
         return (0, 1)
     elif isinstance(op, Categorical):
         p_param = rv.owner.inputs[3]
-        return tuple(range(at.get_vector_length(p_param)))
+        return tuple(range(pt.get_vector_length(p_param)))
     elif isinstance(op, DiscreteUniform):
         lower, upper = constant_fold(rv.owner.inputs[3:])
         return tuple(range(lower, upper + 1))
@@ -437,8 +436,8 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # PyMC does not allow RVs in the logp graph, even if we are just using the shape
     marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
     marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-    marginalized_rv_domain_tensor = at.swapaxes(
-        at.full(
+    marginalized_rv_domain_tensor = pt.swapaxes(
+        pt.full(
             (*marginalized_rv_shape, len(marginalized_rv_domain)),
             marginalized_rv_domain,
             dtype=marginalized_rv.dtype,
@@ -459,7 +458,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         ]
     else:
         # Make sure this rewrite is registered
-        from pymc.aesaraf import local_remove_check_parameter
+        from pymc.pytensorf import local_remove_check_parameter
 
         def logp_fn(marginalized_rv_const, *non_sequences):
             return joint_logp_op(marginalized_rv_const, *non_sequences)
@@ -471,7 +470,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
             mode=Mode().including("local_remove_check_parameter"),
         )
 
-    joint_logps = at.logsumexp(joint_logps, axis=0)
+    joint_logps = pt.logsumexp(joint_logps, axis=0)
 
-    # We have to add dummy logps for the remaining value variables, otherwise AePPL will raise
-    return joint_logps, *(at.constant(0),) * (len(values) - 1)
+    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
+    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
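The logsumexp in the last hunk is where the marginalization happens: for a finite discrete z, log p(y) = logsumexp_k [log P(z = k) + log p(y | z = k)]. A toy NumPy check of that identity (illustrative only, not PyMC code):

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    p_z = np.array([0.3, 0.7])           # P(z = 0), P(z = 1)
    mu = np.array([-1.0, 2.0])           # y | z = k ~ Normal(mu[k], 1)
    y = 0.5

    # Enumerate the domain of z and reduce with logsumexp, as the diff does.
    log_joint = np.log(p_z) + norm.logpdf(y, loc=mu, scale=1.0)
    log_marginal = logsumexp(log_joint)  # exact log p(y)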

pymc_experimental/tests/distributions/test_continuous.py
Lines changed: 5 additions & 4 deletions

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# general imports
-import aesara
 import numpy as np
 import pymc as pm
+
+# general imports
+import pytensor
 import pytest
 import scipy.stats.distributions as sp
 
@@ -45,7 +46,7 @@ class TestGenExtremeClass:
     """
 
     @pytest.mark.xfail(
-        condition=(aesara.config.floatX == "float32"),
+        condition=(pytensor.config.floatX == "float32"),
         reason="PyMC underflows earlier than scipy on float32",
    )
    def test_logp(self):
@@ -62,7 +63,7 @@ def test_logp(self):
             else -np.inf,
         )
 
-        if aesara.config.floatX == "float32":
+        if pytensor.config.floatX == "float32":
             raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
 
     def test_logcdf(self):

0 commit comments