11import numpy as np
22import pandas as pd
33from xgbse .non_parametric import _get_conditional_probs_from_survival
4+ from xgbse .converters import hazard_to_survival
45
56
6- def extrapolate_constant_risk (survival , final_time , n_windows , lags = - 1 ):
7+ def extrapolate_constant_risk (survival , final_time , intervals , lags = - 1 ):
78 """
89 Extrapolate a survival curve assuming constant risk.
910
@@ -13,7 +14,7 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
1314
1415 final_time (Float): Final time for extrapolation
1516
16- n_windows (Int): Number of time windows to compute from last time window in survival to final_time
17+ intervals (Int): Time in each interval between last time in survival dataframe and final time
1718
1819 lags (Int): Lags to compute constant risk.
1920 if negative, will use the last "lags" values
@@ -24,28 +25,23 @@ def extrapolate_constant_risk(survival, final_time, n_windows, lags=-1):
2425 pd.DataFrame: Survival dataset with appended extrapolated windows
2526 """
2627
27- # calculating conditionals and risk at each time window
28- conditionals = _get_conditional_probs_from_survival (survival )
29- window_risk = 1 - conditionals
28+ last_time = survival .columns [- 1 ]
29+ # creating windows for extrapolation
30+ # here we sum intervals in times to exclude the last time, that already is in surv dataframe and
31+ # to include final time in resulting dataframe
32+ extrap_windows = np .arange (last_time + intervals , final_time + intervals , intervals )
3033
31- # calculating window sizes
32- time_bins = window_risk .columns .to_series ()
33- window_sizes = time_bins - time_bins .shift (1 ).fillna (0 )
34+ # calculating conditionals and hazard at each time window
35+ hazards = _get_conditional_probs_from_survival (survival )
3436
35- # using window sizes to calculate risk per unit time and average risk
36- risk_per_unit_time = np .power (window_risk , 1 / window_sizes )
37- average_risk = risk_per_unit_time .iloc [:, lags :].mean (axis = 1 )
37+ # calculating avg hazard for desired lags
38+ constant_haz = hazards .values [:, lags :].mean (axis = 1 ).reshape (- 1 , 1 )
3839
39- # creating windows for extrapolation
40- last_time = survival .columns [- 1 ]
41- extrap_windows = np .linspace (last_time , final_time , n_windows ) - last_time
40+ # repeat hazard for n_windows required
41+ constant_haz = np .tile (constant_haz , len (extrap_windows ))
4242
43- # loop for extrapolated windows
44- for delta_t in extrap_windows :
43+ constant_haz = pd .DataFrame (constant_haz , columns = extrap_windows )
4544
46- # running constant risk extrapolation
47- extrap_survival = np .power (average_risk , delta_t ) * survival .iloc [:, - 1 ]
48- extrap_survival = pd .Series (extrap_survival , name = last_time + delta_t )
49- survival = pd .concat ([survival , extrap_survival ], axis = 1 )
45+ hazards = pd .concat ([hazards , constant_haz ], axis = 1 )
5046
51- return survival
47+ return hazard_to_survival ( hazards )
0 commit comments