diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index d32831c8be..7dc880ba0b 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -144,9 +144,15 @@ def warn_divergences(idata: arviz.InferenceData) -> list[SamplerWarning]: n_div = int(diverging.sum()) if n_div == 0: return [] + + if n_div == 1: + verb, word = "was", "divergence" + else: + verb, word = "were", "divergences" + warning = SamplerWarning( WarningType.DIVERGENCES, - f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.", + f"There {verb} {n_div} {word} after tuning. Increase `target_accept` or reparameterize.", "error", ) return [warning] diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index 1f7ba44791..52d5c5048c 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -16,19 +16,27 @@ import arviz import numpy as np +import pytest from pymc.stats import convergence -def test_warn_divergences(): +@pytest.mark.parametrize( + "diverging, expected_phrase", + [ + pytest.param([1, 0, 1, 0], "were 2 divergences after tuning", id="plural"), + pytest.param([1, 0, 0, 0], "was 1 divergence after tuning", id="singular"), + ], +) +def test_warn_divergences(diverging, expected_phrase): idata = arviz.from_dict( sample_stats={ - "diverging": np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool), + "diverging": np.array([diverging, [0, 0, 0, 0]]).astype(bool), } ) warns = convergence.warn_divergences(idata) assert len(warns) == 1 - assert "2 divergences after tuning" in warns[0].message + assert expected_phrase in warns[0].message def test_warn_treedepth():