Skip to content

Commit 4ca3414

Browse files
ArmavicaricardoV94
authored andcommitted
Upgrade to aesara=2.8.2 and aeppl=0.0.35
1 parent 5a7c827 commit 4ca3414

18 files changed

+51
-35
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ repos:
2626
- types-filelock
2727
- types-setuptools
2828
- arviz
29-
- aesara==2.7.9
30-
- aeppl==0.0.34
29+
- aesara==2.8.2
30+
- aeppl==0.0.35
3131
always_run: true
3232
require_serial: true
3333
pass_filenames: false

conda-envs/environment-dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aeppl=0.0.34
9-
- aesara=2.7.9
8+
- aeppl=0.0.35
9+
- aesara=2.8.2
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/environment-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aeppl=0.0.34
9-
- aesara=2.7.9
8+
- aeppl=0.0.35
9+
- aesara=2.8.2
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/windows-environment-dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aeppl=0.0.34
9-
- aesara=2.7.9
8+
- aeppl=0.0.35
9+
- aesara=2.8.2
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/windows-environment-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aeppl=0.0.34
9-
- aesara=2.7.9
8+
- aeppl=0.0.35
9+
- aesara=2.8.2
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

pymc/aesaraf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from aesara import config, scalar
3838
from aesara.compile.mode import Mode, get_mode
3939
from aesara.gradient import grad
40-
from aesara.graph import local_optimizer
40+
from aesara.graph import node_rewriter
4141
from aesara.graph.basic import (
4242
Apply,
4343
Constant,
@@ -875,7 +875,7 @@ def largest_common_dtype(tensors):
875875
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype
876876

877877

878-
@local_optimizer(tracks=[CheckParameterValue])
878+
@node_rewriter(tracks=[CheckParameterValue])
879879
def local_remove_check_parameter(fgraph, node):
880880
"""Rewrite that removes Aeppl's CheckParameterValue
881881
@@ -885,7 +885,7 @@ def local_remove_check_parameter(fgraph, node):
885885
return [node.inputs[0]]
886886

887887

888-
@local_optimizer(tracks=[CheckParameterValue])
888+
@node_rewriter(tracks=[CheckParameterValue])
889889
def local_check_parameter_to_ninf_switch(fgraph, node):
890890
if isinstance(node.op, CheckParameterValue):
891891
logp_expr, *logp_conds = node.inputs

pymc/distributions/continuous.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from aesara.tensor.math import tanh
3636
from aesara.tensor.random.basic import (
3737
BetaRV,
38-
WeibullRV,
3938
cauchy,
4039
chisquare,
4140
exponential,
@@ -1464,7 +1463,7 @@ def dist(cls, lam, *args, **kwargs):
14641463
lam = at.as_tensor_variable(floatX(lam))
14651464

14661465
# Aesara exponential op is parametrized in terms of mu (1/lam)
1467-
return super().dist([at.inv(lam)], **kwargs)
1466+
return super().dist([at.reciprocal(lam)], **kwargs)
14681467

14691468
def moment(rv, size, mu):
14701469
if not rv_size_is_none(size):
@@ -1487,7 +1486,7 @@ def logcdf(value, mu):
14871486
-------
14881487
TensorVariable
14891488
"""
1490-
lam = at.inv(mu)
1489+
lam = at.reciprocal(mu)
14911490
res = at.switch(
14921491
at.lt(value, 0),
14931492
-np.inf,
@@ -2313,7 +2312,7 @@ def logcdf(value, alpha, inv_beta):
23132312
-------
23142313
TensorVariable
23152314
"""
2316-
beta = at.inv(inv_beta)
2315+
beta = at.reciprocal(inv_beta)
23172316
res = at.switch(
23182317
at.lt(value, 0),
23192318
-np.inf,
@@ -2518,8 +2517,15 @@ def logcdf(value, nu):
25182517

25192518

25202519
# TODO: Remove this once logp for multiplication is working!
2521-
class WeibullBetaRV(WeibullRV):
2520+
class WeibullBetaRV(RandomVariable):
2521+
name = "weibull"
2522+
ndim_supp = 0
25222523
ndims_params = [0, 0]
2524+
dtype = "floatX"
2525+
_print_name = ("Weibull", "\\operatorname{Weibull}")
2526+
2527+
def __call__(self, alpha, beta, size=None, **kwargs):
2528+
return super().__call__(alpha, beta, size=size, **kwargs)
25232529

25242530
@classmethod
25252531
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
@@ -2615,6 +2621,16 @@ def logcdf(value, alpha, beta):
26152621

26162622
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
26172623

2624+
def logp(value, alpha, beta):
2625+
res = (
2626+
at.log(alpha)
2627+
- at.log(beta)
2628+
+ (alpha - 1.0) * at.log(value / beta)
2629+
- at.pow(value / beta, alpha)
2630+
)
2631+
res = at.switch(at.ge(value, 0.0), res, -np.inf)
2632+
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
2633+
26182634

26192635
class HalfStudentTRV(RandomVariable):
26202636
name = "halfstudentt"

pymc/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def sigma2rho(sigma):
157157
"""
158158
`sigma -> rho` Aesara converter
159159
:math:`mu + sigma*e = mu + log(1+exp(rho))*e`"""
160-
return at.log(at.exp(at.abs_(sigma)) - 1.0)
160+
return at.log(at.exp(at.abs(sigma)) - 1.0)
161161

162162

163163
def rho2sigma(rho):
@@ -213,7 +213,7 @@ def log_normal(x, mean, **kwargs):
213213
else:
214214
std = tau ** (-1)
215215
std += f(eps)
216-
return f(c) - at.log(at.abs_(std)) - (x - mean) ** 2 / (2.0 * std**2)
216+
return f(c) - at.log(at.abs(std)) - (x - mean) ** 2 / (2.0 * std**2)
217217

218218

219219
def MvNormalLogp():

pymc/distributions/logprob.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from aeppl.abstract import assign_custom_measurable_outputs
2525
from aeppl.logprob import logcdf as logcdf_aeppl
2626
from aeppl.logprob import logprob as logp_aeppl
27-
from aeppl.transforms import TransformValuesOpt
27+
from aeppl.transforms import TransformValuesRewrite
2828
from aesara.graph.basic import graph_inputs, io_toposort
2929
from aesara.tensor.random.op import RandomVariable
3030
from aesara.tensor.subtensor import (
@@ -231,7 +231,7 @@ def joint_logp(
231231
if original_value_var is not None and hasattr(original_value_var.tag, "transform"):
232232
transform_map[value_var] = original_value_var.tag.transform
233233

234-
transform_opt = TransformValuesOpt(transform_map)
234+
transform_opt = TransformValuesRewrite(transform_map)
235235
temp_logp_var_dict = factorized_joint_logprob(
236236
tmp_rvs_to_values,
237237
extra_rewrites=transform_opt,

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def moment(rv, size, n, p):
539539
n = at.shape_padright(n)
540540
mode = at.round(n * p)
541541
diff = n - at.sum(mode, axis=-1, keepdims=True)
542-
inc_bool_arr = at.abs_(diff) > 0
542+
inc_bool_arr = at.abs(diff) > 0
543543
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
544544
if not rv_size_is_none(size):
545545
output_size = at.concatenate([size, [p.shape[-1]]])

0 commit comments

Comments
 (0)