Skip to content

Commit 024c7e9

Browse files
committed
feat #6686: expect invalid value warn
1 parent db91647 commit 024c7e9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/distributions/test_censored.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,19 @@ def test_censored_logcdf_continuous(self):
131131
norm = pm.Normal.dist(0, 1)
132132
eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
133133

134+
match_str = "divide by zero encountered in log|invalid value encountered in subtract"
135+
134136
# No censoring
135137
censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
136-
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
138+
with pytest.warns(RuntimeWarning, match=match_str):
137139
censored_eval = logcdf(censored_norm, eval_points).eval()
138-
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
140+
with pytest.warns(RuntimeWarning, match=match_str):
139141
norm_eval = logcdf(norm, eval_points).eval()
140142
np.testing.assert_allclose(censored_eval, norm_eval)
141143

142144
# Left censoring
143145
censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
144-
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
146+
with pytest.warns(RuntimeWarning, match=match_str):
145147
censored_eval = logcdf(censored_norm, eval_points).eval()
146148
np.testing.assert_allclose(
147149
censored_eval,
@@ -151,7 +153,7 @@ def test_censored_logcdf_continuous(self):
151153

152154
# Right censoring
153155
censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
154-
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
156+
with pytest.warns(RuntimeWarning, match=match_str):
155157
censored_eval = logcdf(censored_norm, eval_points).eval()
156158
np.testing.assert_allclose(
157159
censored_eval,
@@ -161,7 +163,7 @@ def test_censored_logcdf_continuous(self):
161163

162164
# Interval censoring
163165
censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
164-
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
166+
with pytest.warns(RuntimeWarning, match=match_str):
165167
censored_eval = logcdf(censored_norm, eval_points).eval()
166168
np.testing.assert_allclose(
167169
censored_eval,

0 commit comments

Comments
 (0)