Skip to content

Commit fa26300

Browse files
Merge pull request #36 from loft-br/feat/refactor-extrapolation
Feat/refactor extrapolation
2 parents 1ad72c4 + c7e0010 commit fa26300

File tree

11 files changed

+117
-2544
lines changed

11 files changed

+117
-2544
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ from xgbse.extrapolation import extrapolate_constant_risk
217217
survival = bootstrap_estimator.predict(X_valid)
218218

219219
# extrapolating
220-
survival_ext = extrapolate_constant_risk(survival, 450, 11)
220+
survival_ext = extrapolate_constant_risk(survival, 450, 15)
221221
```
222222

223223
<img src="img/extrapolation.png">
@@ -407,7 +407,7 @@ To cite this repository:
407407
author = {Davi Vieira and Gabriel Gimenez and Guilherme Marmerola and Vitor Estima},
408408
title = {XGBoost Survival Embeddings: improving statistical properties of XGBoost survival analysis implementation},
409409
url = {http://github.com/loft-br/xgboost-survival-embeddings},
410-
version = {0.2.1},
410+
version = {0.2.2},
411411
year = {2021},
412412
}
413413
```

docs/examples/extrapolation_example.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ Notice that this predicted survival curve does not end at zero (cure fraction du
383383
from xgbse.extrapolation import extrapolate_constant_risk
384384

385385
# extrapolating predicted survival
386-
survival_ext = extrapolate_constant_risk(survival, 450, 11)
386+
survival_ext = extrapolate_constant_risk(survival, 450, 15)
387387
survival_ext.head()
388388
```
389389

examples/extrapolation_example.ipynb

Lines changed: 75 additions & 2508 deletions
Large diffs are not rendered by default.

img/extrapolation.png

-58.5 KB
Loading

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
setuptools.setup(
3636
name="xgbse",
37-
version="0.2.1",
37+
version="0.2.2",
3838
author="Loft Data Science Team",
3939
author_email="bandits@loft.com.br",
4040
description="Improving XGBoost survival analysis with embeddings and debiased estimators",

tests/test_extrapolation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
)
4343

4444
preds = xgbse_model.predict(X_test)
45+
interval = 10
46+
final_time = max(time_bins) + 1000
4547
n_windows = 100
46-
final_time = max(T_train) + 1000
47-
preds_ext = extrapolate_constant_risk(preds, final_time=final_time, n_windows=n_windows)
48+
49+
preds_ext = extrapolate_constant_risk(preds, final_time=final_time, intervals=interval)
4850

4951

5052
def extrapolation_shape():

xgbse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ._meta import XGBSEBootstrapEstimator
77

88

9-
__version__ = "0.2.1"
9+
__version__ = "0.2.2"
1010

1111
__all__ = [
1212
"XGBSEDebiasedBCE",

xgbse/_debiased_bce.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# lib utils
1111
from xgbse._base import XGBSEBaseEstimator, DummyLogisticRegression
12-
from xgbse.converters import convert_data_to_xgb_format, convert_y
12+
from xgbse.converters import convert_data_to_xgb_format, convert_y, hazard_to_survival
1313

1414
# at which percentiles will the KM predict
1515
from xgbse.non_parametric import get_time_bins, calculate_interval_failures
@@ -346,9 +346,7 @@ def _predict_from_lr_list(self, lr_estimators, leaves_encoded, time_bins):
346346

347347
# converting these interval predictions
348348
# to cumulative survival curve
349-
preds = (1 - preds).cumprod(axis=1)
350-
351-
return preds
349+
return hazard_to_survival(preds)
352350

353351
def predict(self, X, return_interval_probs=False):
354352
"""

xgbse/converters.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,16 @@ def build_xgb_cox_dmatrix(X, T, E):
112112
target = np.where(E, T, -T)
113113

114114
return xgb.DMatrix(X, label=target)
115+
116+
117+
def hazard_to_survival(interval):
118+
"""Convert hazards (interval probabilities of event) into survival curve
119+
120+
Args:
121+
interval ([pd.DataFrame, np.array]): hazards (interval probabilities of event)
122+
usually result of predict or result from _get_point_probs_from_survival
123+
124+
Returns:
125+
[pd.DataFrame, np.array]: survival curve
126+
"""
127+
return (1 - interval).cumprod(axis=1)

xgbse/extrapolation.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
import pandas as pd
33
from xgbse.non_parametric import _get_conditional_probs_from_survival
4+
from xgbse.converters import hazard_to_survival
45

56

6-
def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
7+
def extrapolate_constant_risk(survival, final_time, intervals, lags=-1):
78
"""
89
Extrapolate a survival curve assuming constant risk.
910
@@ -13,7 +14,7 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
1314
1415
final_time (Float): Final time for extrapolation
1516
16-
n_windows (Int): Number of time windows to compute from last time window in survival to final_time
17+
intervals (Int): Step size (in time units) between consecutive extrapolated time points, starting after the last time in the survival dataframe and continuing until final_time is covered
1718
1819
lags (Int): Lags to compute constant risk.
1920
if negative, will use the last "lags" values
@@ -24,28 +25,23 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
2425
pd.DataFrame: Survival dataset with appended extrapolated windows
2526
"""
2627

27-
# calculating conditionals and risk at each time window
28-
conditionals = _get_conditional_probs_from_survival(survival)
29-
window_risk = 1 - conditionals
28+
last_time = survival.columns[-1]
29+
# creating windows for extrapolation
30+
# adding `intervals` to the start skips last_time, which is already a column of the survival dataframe;
31+
# adding `intervals` to the stop keeps final_time in range, since np.arange excludes its stop value
32+
extrap_windows = np.arange(last_time + intervals, final_time + intervals, intervals)
3033

31-
# calculating window sizes
32-
time_bins = window_risk.columns.to_series()
33-
window_sizes = time_bins - time_bins.shift(1).fillna(0)
34+
# calculating conditionals and hazard at each time window
35+
hazards = _get_conditional_probs_from_survival(survival)
3436

35-
# using window sizes to calculate risk per unit time and average risk
36-
risk_per_unit_time = np.power(window_risk, 1 / window_sizes)
37-
average_risk = risk_per_unit_time.iloc[:, lags:].mean(axis=1)
37+
# calculating avg hazard for desired lags
38+
constant_haz = hazards.values[:, lags:].mean(axis=1).reshape(-1, 1)
3839

39-
# creating windows for extrapolation
40-
last_time = survival.columns[-1]
41-
extrap_windows = np.linspace(last_time, final_time, n_windows) - last_time
40+
# repeat the constant hazard once for each extrapolated window
41+
constant_haz = np.tile(constant_haz, len(extrap_windows))
4242

43-
# loop for extrapolated windows
44-
for delta_t in extrap_windows:
43+
constant_haz = pd.DataFrame(constant_haz, columns=extrap_windows)
4544

46-
# running constant risk extrapolation
47-
extrap_survival = np.power(average_risk, delta_t) * survival.iloc[:, -1]
48-
extrap_survival = pd.Series(extrap_survival, name=last_time + delta_t)
49-
survival = pd.concat([survival, extrap_survival], axis=1)
45+
hazards = pd.concat([hazards, constant_haz], axis=1)
5046

51-
return survival
47+
return hazard_to_survival(hazards)

0 commit comments

Comments
 (0)