Skip to content

Commit 167355b

Browse files
ArmavicaricardoV94
authored andcommitted
Fix deprecated pytest.warns(None)
1 parent 7c90c9b commit 167355b

File tree

6 files changed

+41
-28
lines changed

6 files changed

+41
-28
lines changed

pymc/tests/test_distributions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import functools
1515
import itertools
1616
import sys
17+
import warnings
1718

1819
import aesara
1920
import aesara.tensor as at
@@ -3288,9 +3289,9 @@ def test_no_warning_logp(self):
32883289
with pm.Model() as m:
32893290
sd_dist = pm.Exponential.dist(1, size=3)
32903291
x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
3291-
with pytest.warns(None) as record:
3292+
with warnings.catch_warnings():
3293+
warnings.simplefilter("error")
32923294
m.logp()
3293-
assert not record
32943295

32953296
@pytest.mark.parametrize(
32963297
"sd_dist",

pymc/tests/test_math.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
import aesara
1618
import aesara.tensor as at
1719
import numpy as np
@@ -154,9 +156,9 @@ def test_log1mexp():
154156

155157
def test_log1mexp_numpy_no_warning():
156158
"""Assert RuntimeWarning is not raised for very small numbers"""
157-
with pytest.warns(None) as record:
159+
with warnings.catch_warnings():
160+
warnings.simplefilter("error")
158161
log1mexp_numpy(-1e-25, negative_input=True)
159-
assert not record
160162

161163

162164
def test_log1mexp_numpy_integer_input():
@@ -170,17 +172,18 @@ def test_log1mexp_deprecation_warnings():
170172
):
171173
res_pos = log1mexp_numpy(2)
172174

173-
with pytest.warns(None) as record:
175+
with warnings.catch_warnings():
176+
warnings.simplefilter("error")
174177
res_neg = log1mexp_numpy(-2, negative_input=True)
175-
assert not record
176178

177179
with pytest.warns(
178180
FutureWarning,
179181
match="pymc.math.log1mexp will expect a negative input",
180182
):
181183
res_pos_at = log1mexp(2).eval()
182184

183-
with pytest.warns(None):
185+
with warnings.catch_warnings():
186+
warnings.simplefilter("error")
184187
res_neg_at = log1mexp(-2, negative_input=True).eval()
185188

186189
assert np.isclose(res_pos, res_neg)
@@ -262,9 +265,9 @@ def test_invlogit_deprecation_warning():
262265
):
263266
res = invlogit(np.array(-750.0), 1e-5).eval()
264267

265-
with pytest.warns(None) as record:
268+
with warnings.catch_warnings():
269+
warnings.simplefilter("error")
266270
res_zero_eps = invlogit(np.array(-750.0)).eval()
267-
assert not record
268271

269272
assert np.isclose(res, res_zero_eps)
270273

@@ -280,11 +283,10 @@ def test_softmax_logsoftmax_no_warnings(aesara_function, pymc_wrapper):
280283
"""Test that wrappers for aesara functions do not issue Warnings"""
281284

282285
vector = at.vector("vector")
283-
with pytest.warns(None) as record:
286+
with pytest.warns(Warning) as record:
284287
aesara_function(vector)
285-
warnings = {warning.category for warning in record.list}
286-
assert warnings == {UserWarning, FutureWarning}
288+
assert {w.category for w in record.list} == {UserWarning, FutureWarning}
287289

288-
with pytest.warns(None) as record:
290+
with warnings.catch_warnings():
291+
warnings.simplefilter("error")
289292
pymc_wrapper(vector)
290-
assert not record

pymc/tests/test_mixture.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
from contextlib import ExitStack as does_not_raise
1618

1719
import aesara
@@ -655,10 +657,10 @@ def mixmixlogp(value, point):
655657
assert_allclose(priorlogp + mixmixlogpg.sum(), model.logp(test_point), rtol=rtol)
656658

657659
def test_iterable_single_component_warning(self):
658-
with pytest.warns(None) as record:
660+
with warnings.catch_warnings():
661+
warnings.simplefilter("error")
659662
Mixture.dist(w=[0.5, 0.5], comp_dists=Normal.dist(size=2))
660663
Mixture.dist(w=[0.5, 0.5], comp_dists=[Normal.dist(size=2), Normal.dist(size=2)])
661-
assert not record
662664

663665
with pytest.warns(UserWarning, match="Single component will be treated as a mixture"):
664666
Mixture.dist(w=[0.5, 0.5], comp_dists=[Normal.dist(size=2)])
@@ -1303,9 +1305,9 @@ def test_logp(self):
13031305
def test_warning(self):
13041306
with Model() as m:
13051307
comp_dists = [HalfNormal.dist(), Exponential.dist(1)]
1306-
with pytest.warns(None) as rec:
1308+
with warnings.catch_warnings():
1309+
warnings.simplefilter("error")
13071310
Mixture("mix1", w=[0.5, 0.5], comp_dists=comp_dists)
1308-
assert not rec
13091311

13101312
comp_dists = [Uniform.dist(0, 1), Uniform.dist(0, 2)]
13111313
with pytest.warns(MixtureTransformWarning):
@@ -1315,16 +1317,16 @@ def test_warning(self):
13151317
with pytest.warns(MixtureTransformWarning):
13161318
Mixture("mix3", w=[0.5, 0.5], comp_dists=comp_dists)
13171319

1318-
with pytest.warns(None) as rec:
1320+
with warnings.catch_warnings():
1321+
warnings.simplefilter("error")
13191322
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
1320-
assert not rec
13211323

1322-
with pytest.warns(None) as rec:
1324+
with warnings.catch_warnings():
1325+
warnings.simplefilter("error")
13231326
Mixture("mix5", w=[0.5, 0.5], comp_dists=comp_dists, observed=1)
1324-
assert not rec
13251327

13261328
# Case where the appropriate default transform is None
13271329
comp_dists = [Normal.dist(), Normal.dist()]
1328-
with pytest.warns(None) as rec:
1330+
with warnings.catch_warnings():
1331+
warnings.simplefilter("error")
13291332
Mixture("mix6", w=[0.5, 0.5], comp_dists=comp_dists)
1330-
assert not rec

pymc/tests/test_model_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
import aesara
1517
import numpy as np
1618
import pytest
@@ -271,7 +273,8 @@ class TestRadonModel(BaseModelGraphTest):
271273
model_func = radon_model
272274

273275
def test_checks_formatting(self):
274-
with pytest.warns(None):
276+
with warnings.catch_warnings():
277+
warnings.simplefilter("error")
275278
model_to_graphviz(self.model, formatting="plain")
276279
with pytest.raises(ValueError, match="Unsupported formatting"):
277280
model_to_graphviz(self.model, formatting="latex")

pymc/tests/test_sampling_jax.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from typing import Any, Callable, Dict, Optional
24
from unittest import mock
35

@@ -113,9 +115,9 @@ def test_get_jaxified_graph():
113115
# be removed once https://github.com/aesara-devs/aesara/issues/637 is sorted.
114116
x = at.scalar("x")
115117
y = at.exp(x)
116-
with pytest.warns(None) as record:
118+
with warnings.catch_warnings():
119+
warnings.simplefilter("error")
117120
fn = get_jaxified_graph(inputs=[x], outputs=[y])
118-
assert not record
119121
assert np.isclose(fn(0), 1)
120122

121123

pymc/tests/test_shape_handling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
import aesara
1618
import numpy as np
1719
import pytest
@@ -238,7 +240,8 @@ def test_param_and_batch_shape_combos(
238240
with pm.Model(coords=coords) as pmodel:
239241
mu = aesara.shared(np.random.normal(size=param_shape))
240242

241-
with pytest.warns(None):
243+
with warnings.catch_warnings():
244+
warnings.simplefilter("error")
242245
if parametrization == "implicit":
243246
rv = pm.Normal("rv", mu=mu).shape == param_shape
244247
else:

0 commit comments

Comments
 (0)