Skip to content

Commit c21f8f8

Browse files
Merge pull request #1672 from haeunhwangbo/fix-add-at-risk-counts-numpy-2-4
Fix add_at_risk_counts for NumPy>=2.4
2 parents 839594f + 7296992 commit c21f8f8

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

lifelines/plotting.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,10 @@ def add_at_risk_counts(
540540
event_table_slice.loc[:tick, ["at_risk", "censored", "observed"]]
541541
.agg(
542542
{
543-
"at_risk": lambda x: x.tail(1).values,
543+
# `Series.tail(1).values` is a 1D array of length 1. In NumPy>=2.4,
544+
# `int(np.array([1]))` raises `TypeError: only 0-dimensional arrays can be converted to Python scalars`.
545+
# Extract a Python scalar for compatibility.
546+
"at_risk": lambda x: x.tail(1).values.item(),
544547
"censored": "sum",
545548
"observed": "sum",
546549
}
@@ -554,7 +557,7 @@ def add_at_risk_counts(
554557
)
555558
.fillna(0)
556559
)
557-
counts.extend([int(c) for c in event_table_slice.loc[rows_to_show]])
560+
counts.extend([int(np.asarray(c).item()) for c in event_table_slice.loc[rows_to_show]])
558561
else:
559562
counts.extend([0 for _ in range(n_rows)])
560563
if n_rows > 1:

lifelines/tests/test_plotting.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ def waltons():
4444
return load_waltons()[["T", "E"]].iloc[:50]
4545

4646

47+
def test_add_at_risk_counts_is_numpy_scalar_compatible():
48+
# Regression test for https://github.com/CamDavidsonPilon/lifelines/issues/1671:
49+
# `add_at_risk_counts` previously produced 1D NumPy arrays (length 1) for "At risk" counts,
50+
# which breaks `int(c)` on NumPy>=2.4 (`TypeError: only 0-dimensional arrays can be converted to Python scalars`).
51+
matplotlib = pytest.importorskip("matplotlib")
52+
matplotlib.use("Agg", force=True)
53+
from matplotlib import pyplot as plt
54+
55+
kmf = KaplanMeierFitter().fit(
56+
np.random.exponential(10, size=(100,)),
57+
np.random.binomial(1, 0.8, size=(100,)),
58+
)
59+
ax = kmf.plot_survival_function()
60+
add_at_risk_counts(kmf, ax=ax)
61+
plt.close(ax.figure)
62+
63+
4764
@pytest.mark.skipif("DISPLAY" not in os.environ, reason="requires display")
4865
class TestPlotting:
4966
@pytest.fixture

0 commit comments

Comments
 (0)