Skip to content

Commit dde2d83

Browse files
Update deprecations and set filterwarnings to error during testing
1 parent 1bd2f56 commit dde2d83

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

aehmc/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import Callable, Tuple
22

33
import aesara.tensor as at
4-
import aesara.tensor.slinalg as slinalg
54
from aesara.tensor.random.utils import RandomStream
65
from aesara.tensor.shape import shape_tuple
6+
from aesara.tensor.slinalg import cholesky, solve_triangular
77
from aesara.tensor.var import TensorVariable
88

99

@@ -51,9 +51,9 @@ def gaussian_metric(
5151
dot, matmul = at.dot, lambda x, y: x * y
5252
elif inverse_mass_matrix.ndim == 2:
5353
shape = (shape_tuple(inverse_mass_matrix)[0],)
54-
tril_inv = slinalg.cholesky(inverse_mass_matrix)
54+
tril_inv = cholesky(inverse_mass_matrix)
5555
identity = at.eye(*shape)
56-
mass_matrix_sqrt = slinalg.solve_lower_triangular(tril_inv, identity)
56+
mass_matrix_sqrt = solve_triangular(tril_inv, identity, lower=True)
5757
dot, matmul = at.dot, at.dot
5858
else:
5959
raise ValueError(

aehmc/proposals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def update(initial_energy, state):
4040

4141
delta_energy = initial_energy - new_energy
4242
delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy)
43-
is_transition_divergent = at.abs_(delta_energy) > divergence_threshold
43+
is_transition_divergent = at.abs(delta_energy) > divergence_threshold
4444

4545
weight = delta_energy
4646
log_p_accept = at.where(

aehmc/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aesara.graph.basic import Variable, ancestors
66
from aesara.graph.fg import FunctionGraph
77
from aesara.graph.rewriting.utils import rewrite_graph
8-
from aesara.tensor.rewriting.shape import ShapeFeature
8+
from aesara.tensor.rewriting.basic import ShapeFeature
99
from aesara.tensor.var import TensorVariable
1010

1111

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ convention = numpy
1616
[tool:pytest]
1717
python_files=test*.py
1818
testpaths=tests
19+
filterwarnings=
20+
error:::aesara
21+
error:::aeppl
22+
error:::aemcmc
23+
ignore:::xarray
1924

2025
[coverage:run]
2126
omit =

0 commit comments

Comments
 (0)