2
2
3
3
import logging
4
4
from math import ceil
5
+ from typing import Any
5
6
6
7
import numpy as np
7
8
import pandas as pd
8
9
import statsmodels .formula .api as smf
9
10
from lifelines import CoxPHFitter
10
11
11
- from causal_testing .specification .capabilities import TreatmentSequence , Capability
12
12
from causal_testing .estimation .abstract_estimator import Estimator
13
13
14
14
logger = logging .getLogger (__name__ )
15
15
16
16
17
17
class IPCWEstimator (Estimator ):
18
18
"""
19
- Class to perform inverse probability of censoring weighting (IPCW) estimation
19
+ Class to perform Inverse Probability of Censoring Weighting (IPCW) estimation
20
20
for sequences of treatments over time-varying data.
21
21
"""
22
22
@@ -25,37 +25,57 @@ class IPCWEstimator(Estimator):
25
25
def __init__(
    self,
    df: pd.DataFrame,
    timesteps_per_observation: int,
    control_strategy: list[tuple[int, str, Any]],
    treatment_strategy: list[tuple[int, str, Any]],
    outcome: str,
    status_column: str,
    fit_bl_switch_formula: str,
    fit_bltd_switch_formula: str,
    eligibility=None,
    alpha: float = 0.05,
    total_time: float | None = None,
):
    """
    Initialise IPCWEstimator.

    :param: df: Input DataFrame containing time-varying data.
    :param: timesteps_per_observation: Number of timesteps per observation.
    :param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
    :param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
    :param: outcome: Name of the outcome column in the DataFrame.
    :param: status_column: Name of the status column in the DataFrame, which should be True for operating
            normally, False for a fault.
    :param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
    :param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
    :param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
    :param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
    :param: total_time: Total time for the analysis. Defaults to one plus the length of the strategy (control or
            treatment) with the most elements, multiplied by `timesteps_per_observation`.
    """
    super().__init__(
        # Strategies are (timestep, variable, value) triples; split them into
        # the variable/value lists the base Estimator expects.
        [var for _, var, _ in treatment_strategy],
        [val for _, _, val in treatment_strategy],
        [val for _, _, val in control_strategy],
        None,
        outcome,
        df,
        None,
        alpha=alpha,
        query="",
    )
    self.timesteps_per_observation = timesteps_per_observation
    self.control_strategy = control_strategy
    self.treatment_strategy = treatment_strategy
    self.outcome = outcome
    self.status_column = status_column
    self.fit_bl_switch_formula = fit_bl_switch_formula
    self.fit_bltd_switch_formula = fit_bltd_switch_formula
    self.eligibility = eligibility
    self.df = df
    # BUG FIX: previously self.total_time was only assigned when total_time was
    # None, so an explicitly supplied total_time was silently ignored and
    # self.total_time was left undefined (AttributeError in preprocess_data).
    if total_time is None:
        total_time = (
            max(len(self.control_strategy), len(self.treatment_strategy)) + 1
        ) * self.timesteps_per_observation
    self.total_time = total_time
    self.preprocess_data()
60
80
61
81
def add_modelling_assumptions (self ):
@@ -92,16 +112,16 @@ def setup_fault_t_do(self, individual: pd.DataFrame):
92
112
index is the time point at which the event of interest (i.e. a fault)
93
113
occurred.
94
114
"""
95
- fault = individual [~ individual [self .fault_column ]]
115
+ fault = individual [~ individual [self .status_column ]]
96
116
fault_t_do = pd .Series (np .zeros (len (individual )), index = individual .index )
97
117
98
118
if not fault .empty :
99
119
fault_time = individual ["time" ].loc [fault .index [0 ]]
100
120
# Ceiling to nearest observation point
101
- fault_time = ceil (fault_time / self .timesteps_per_intervention ) * self .timesteps_per_intervention
121
+ fault_time = ceil (fault_time / self .timesteps_per_observation ) * self .timesteps_per_observation
102
122
# Set the correct observation point to be the fault time of doing (fault_t_do)
103
123
observations = individual .loc [
104
- (individual ["time" ] % self .timesteps_per_intervention == 0 ) & (individual ["time" ] < fault_time )
124
+ (individual ["time" ] % self .timesteps_per_observation == 0 ) & (individual ["time" ] < fault_time )
105
125
]
106
126
if not observations .empty :
107
127
fault_t_do .loc [observations .index [0 ]] = 1
@@ -113,11 +133,11 @@ def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.00
113
133
"""
114
134
Return the time at which the event of interest (i.e. a fault) occurred.
115
135
"""
116
- fault = individual [~ individual [self .fault_column ]]
136
+ fault = individual [~ individual [self .status_column ]]
117
137
fault_time = (
118
138
individual ["time" ].loc [fault .index [0 ]]
119
139
if not fault .empty
120
- else (individual ["time" ].max () + self .timesteps_per_intervention )
140
+ else (individual ["time" ].max () + self .timesteps_per_observation )
121
141
)
122
142
return pd .DataFrame ({"fault_time" : np .repeat (fault_time + perturbation , len (individual ))})
123
143
@@ -130,15 +150,14 @@ def preprocess_data(self):
130
150
self .df ["eligible" ] = self .df .eval (self .eligibility ) if self .eligibility is not None else True
131
151
132
152
# when did a fault occur?
133
- self .df ["fault_time" ] = self .df .groupby ("id" )[[self .fault_column , "time" ]].apply (self .setup_fault_time ).values
153
+ self .df ["fault_time" ] = self .df .groupby ("id" )[[self .status_column , "time" ]].apply (self .setup_fault_time ).values
134
154
self .df ["fault_t_do" ] = (
135
- self .df .groupby ("id" )[["id" , "time" , self .fault_column ]].apply (self .setup_fault_t_do ).values
155
+ self .df .groupby ("id" )[["id" , "time" , self .status_column ]].apply (self .setup_fault_t_do ).values
136
156
)
137
157
assert not pd .isnull (self .df ["fault_time" ]).any ()
138
158
139
159
living_runs = self .df .query ("fault_time > 0" ).loc [
140
- (self .df ["time" ] % self .timesteps_per_intervention == 0 )
141
- & (self .df ["time" ] <= self .control_strategy .total_time ())
160
+ (self .df ["time" ] % self .timesteps_per_observation == 0 ) & (self .df ["time" ] <= self .total_time )
142
161
]
143
162
144
163
individuals = []
@@ -152,25 +171,20 @@ def preprocess_data(self):
152
171
)
153
172
154
173
strategy_followed = [
155
- Capability (
156
- c .variable ,
157
- individual .loc [individual ["time" ] == c .start_time , c .variable ].values [0 ],
158
- c .start_time ,
159
- c .end_time ,
160
- )
161
- for c in self .treatment_strategy .capabilities
174
+ (t , var , individual .loc [individual ["time" ] == t , var ].values [0 ])
175
+ for t , var , val in self .treatment_strategy
162
176
]
163
177
164
178
# Control flow:
165
179
# Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
166
180
# Individuals that don't start off in either arm are left out
167
181
for inx , strategy_assigned in [(0 , self .control_strategy ), (1 , self .treatment_strategy )]:
168
- if strategy_assigned . capabilities [0 ] == strategy_followed [0 ] and individual .eligible .iloc [0 ]:
182
+ if strategy_assigned [0 ] == strategy_followed [0 ] and individual .eligible .iloc [0 ]:
169
183
individual ["id" ] = new_id
170
184
new_id += 1
171
185
individual ["trtrand" ] = inx
172
186
individual ["xo_t_do" ] = self .setup_xo_t_do (
173
- strategy_assigned . capabilities , strategy_followed , individual ["eligible" ]
187
+ strategy_assigned , strategy_followed , individual ["eligible" ]
174
188
)
175
189
individuals .append (individual .loc [individual ["time" ] <= individual ["fault_time" ]].copy ())
176
190
if len (individuals ) == 0 :
@@ -222,7 +236,7 @@ def estimate_hazard_ratio(self):
222
236
223
237
preprocessed_data ["tin" ] = preprocessed_data ["time" ]
224
238
preprocessed_data ["tout" ] = pd .concat (
225
- [(preprocessed_data ["time" ] + self .timesteps_per_intervention ), preprocessed_data ["fault_time" ]],
239
+ [(preprocessed_data ["time" ] + self .timesteps_per_observation ), preprocessed_data ["fault_time" ]],
226
240
axis = 1 ,
227
241
).min (axis = 1 )
228
242
0 commit comments