From 5a81852534766ec5a9ffd64f4d8f8b8549bda849 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 12 Jun 2025 13:20:19 -0400 Subject: [PATCH] change verb and word based on number of divergences --- pymc/stats/convergence.py | 8 +++++++- tests/stats/test_convergence.py | 14 +++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index d32831c8be..7dc880ba0b 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -144,9 +144,15 @@ def warn_divergences(idata: arviz.InferenceData) -> list[SamplerWarning]: n_div = int(diverging.sum()) if n_div == 0: return [] + + if n_div == 1: + verb, word = "was", "divergence" + else: + verb, word = "were", "divergences" + warning = SamplerWarning( WarningType.DIVERGENCES, - f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.", + f"There {verb} {n_div} {word} after tuning. Increase `target_accept` or reparameterize.", "error", ) return [warning] diff --git a/tests/stats/test_convergence.py b/tests/stats/test_convergence.py index 1f7ba44791..52d5c5048c 100644 --- a/tests/stats/test_convergence.py +++ b/tests/stats/test_convergence.py @@ -16,19 +16,27 @@ import arviz import numpy as np +import pytest from pymc.stats import convergence -def test_warn_divergences(): +@pytest.mark.parametrize( + "diverging, expected_phrase", + [ + pytest.param([1, 0, 1, 0], "were 2 divergences after tuning", id="plural"), + pytest.param([1, 0, 0, 0], "was 1 divergence after tuning", id="singular"), + ], +) +def test_warn_divergences(diverging, expected_phrase): idata = arviz.from_dict( sample_stats={ - "diverging": np.array([[1, 0, 1, 0], [0, 0, 0, 0]]).astype(bool), + "diverging": np.array([diverging, [0, 0, 0, 0]]).astype(bool), } ) warns = convergence.warn_divergences(idata) assert len(warns) == 1 - assert "2 divergences after tuning" in warns[0].message + assert expected_phrase in warns[0].message def test_warn_treedepth():