Skip to content

Commit 6d3cedf

Browse files
Merge pull request #1576 from Pearcekieser/master
Add throw_on_fail setting to check_assumptions
2 parents c9b136b + 9f6edd0 commit 6d3cedf

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

lifelines/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ class StatError(Exception):
55
pass
66

77

8+
class ProportionalHazardAssumptionError(Exception):
9+
pass
10+
11+
812
class ConvergenceError(ValueError):
913
# inherits from ValueError for backwards compatibility reasons
1014
def __init__(self, msg, original_exception=""):

lifelines/fitters/mixins.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from autograd import numpy as anp
55
import numpy as np
66
from pandas import DataFrame, Series
7+
from lifelines.exceptions import ProportionalHazardAssumptionError
78
from lifelines.statistics import proportional_hazard_test, TimeTransformers
89
from lifelines.utils import format_p_value
910
from lifelines.utils.lowess import lowess
@@ -28,6 +29,7 @@ def check_assumptions(
2829
p_value_threshold: float = 0.01,
2930
plot_n_bootstraps: int = 15,
3031
columns: Optional[List[str]] = None,
32+
raise_on_fail: bool = False,
3133
) -> None:
3234
"""
3335
Use this function to test the proportional hazards assumption. See usage example at
@@ -51,6 +53,8 @@ def check_assumptions(
5153
the function significantly.
5254
columns: list, optional
5355
specify a subset of columns to test.
56+
raise_on_fail: bool, optional
57+
raise a ``ProportionalHazardAssumptionError`` if the test fails. Default: False.
5458
5559
Returns
5660
--------
@@ -107,7 +111,7 @@ def check_assumptions(
107111

108112
for variable in self.params_.index.intersection(columns or self.params_.index):
109113
minumum_observed_p_value = test_results.summary.loc[variable, "p"].min()
110-
114+
111115
# plot is done (regardless of test result) whenever `show_plots = True`
112116
if show_plots:
113117
axes.append([])
@@ -224,9 +228,8 @@ def check_assumptions(
224228
),
225229
end="\n\n",
226230
)
227-
#################
231+
#################
228232

229-
230233
if advice and counter > 0:
231234
print(
232235
dedent(
@@ -243,6 +246,8 @@ def check_assumptions(
243246

244247
if counter == 0:
245248
print("Proportional hazard assumption looks okay.")
249+
elif raise_on_fail:
250+
raise ProportionalHazardAssumptionError()
246251
return axes
247252

248253
@property

lifelines/tests/test_estimation.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@
3434
qth_survival_time,
3535
)
3636

37-
from lifelines.exceptions import StatisticalWarning, ApproximationWarning, StatError, ConvergenceWarning, ConvergenceError
37+
from lifelines.exceptions import (
38+
ProportionalHazardAssumptionError,
39+
StatisticalWarning,
40+
ApproximationWarning,
41+
StatError,
42+
ConvergenceWarning,
43+
ConvergenceError,
44+
)
3845
from lifelines.fitters import BaseFitter, ParametricUnivariateFitter, ParametricRegressionFitter, RegressionFitter
3946
from lifelines.fitters.coxph_fitter import SemiParametricPHFitter
4047

@@ -3119,9 +3126,14 @@ def test_formulas_can_be_used_with_prediction(self, rossi, cph):
31193126

31203127
def test_formulas_handles_categories_at_inference(self, cph):
31213128
# Create a dummy dataset with one continuous and one categorical feature
3122-
df = pd.DataFrame({
3123-
'time': [1, 2, 3, 1, 2, 3], 'event': [0, 1, 1, 1, 0, 0],
3124-
'cov_cont':[0.1, 0.2, 0.3, 0.1, 0.2, 0.3], 'cov_categ': ['A', 'A', 'B', 'B', 'C', 'C']})
3129+
df = pd.DataFrame(
3130+
{
3131+
"time": [1, 2, 3, 1, 2, 3],
3132+
"event": [0, 1, 1, 1, 0, 0],
3133+
"cov_cont": [0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
3134+
"cov_categ": ["A", "A", "B", "B", "C", "C"],
3135+
}
3136+
)
31253137
cph.fit(df, "time", "event", formula="cov_cont + C(cov_categ)")
31263138
cph.predict_survival_function(df.iloc[:4])
31273139

@@ -3402,6 +3414,11 @@ def test_check_assumptions(self, cph, rossi):
34023414
cph.fit(rossi, "week", "arrest")
34033415
cph.check_assumptions(rossi)
34043416

3417+
def test_check_assumptions_thows_if_raise_on_fail_enalbed(self, cph, rossi):
3418+
cph.fit(rossi, "week", "arrest")
3419+
with pytest.raises(ProportionalHazardAssumptionError):
3420+
cph.check_assumptions(rossi, p_value_threshold=0.05, raise_on_fail=True)
3421+
34053422
def test_check_assumptions_for_subset_of_columns(self, cph, rossi):
34063423
cph.fit(rossi, "week", "arrest")
34073424
cph.check_assumptions(rossi, columns=["age"])

0 commit comments

Comments
 (0)