Skip to content

Commit 349c948

Browse files
authored
change verb and word based on number of divergences (#7817)
1 parent e6d3390 commit 349c948

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

pymc/stats/convergence.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,15 @@ def warn_divergences(idata: arviz.InferenceData) -> list[SamplerWarning]:
144144
n_div = int(diverging.sum())
145145
if n_div == 0:
146146
return []
147+
148+
if n_div == 1:
149+
verb, word = "was", "divergence"
150+
else:
151+
verb, word = "were", "divergences"
152+
147153
warning = SamplerWarning(
148154
WarningType.DIVERGENCES,
149-
f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.",
155+
f"There {verb} {n_div} {word} after tuning. Increase `target_accept` or reparameterize.",
150156
"error",
151157
)
152158
return [warning]

tests/stats/test_convergence.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,27 @@
1616

1717
import arviz
1818
import numpy as np
19+
import pytest
1920

2021
from pymc.stats import convergence
2122

2223

23-
def test_warn_divergences():
24+
@pytest.mark.parametrize(
25+
"diverging, expected_phrase",
26+
[
27+
pytest.param([1, 0, 1, 0], "were 2 divergences after tuning", id="plural"),
28+
pytest.param([1, 0, 0, 0], "was 1 divergence after tuning", id="singular"),
29+
],
30+
)
31+
def test_warn_divergences(diverging, expected_phrase):
2432
idata = arviz.from_dict(
2533
sample_stats={
26-
"diverging": np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool),
34+
"diverging": np.array([diverging, [0, 0, 0, 0]]).astype(bool),
2735
}
2836
)
2937
warns = convergence.warn_divergences(idata)
3038
assert len(warns) == 1
31-
assert "2 divergences after tuning" in warns[0].message
39+
assert expected_phrase in warns[0].message
3240

3341

3442
def test_warn_treedepth():

0 commit comments

Comments
 (0)