@@ -83,45 +83,12 @@ def __init__(
            max(len(self.control_strategy), len(self.treatment_strategy)) + 1
        ) * self.timesteps_per_observation
        self.total_time = total_time
-        print("PREPROCESSING")
        self.preprocess_data()
-        print("PREPROCESSED")

    def add_modelling_assumptions(self):
        self.modelling_assumptions.append("The variables in the data vary over time.")

-    def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligible: pd.Series, time: pd.Series):
-        """
-        Return a binary sequence with each bit representing whether the current
-        index is the time point at which the individual diverted from the
-        assigned treatment strategy (and thus should be censored).
-
-        :param strategy_assigned - the assigned treatment strategy
-        :param strategy_followed - the strategy followed by the individual
-        :param eligible - binary sequence representing the eligibility of the individual at each time step
-        :param time - the sequence of time steps
-        """
-
-        default = {t: (-1, -1) for t in time.values}
-        strategy_assigned = default | {t: (var, val) for t, var, val in strategy_assigned}
-        strategy_followed = default | {t: (var, val) for t, var, val in strategy_followed}
-
-        strategy_assigned = sorted([(t, var, val) for t, (var, val) in strategy_assigned.items() if t in time.values])
-        strategy_followed = sorted([(t, var, val) for t, (var, val) in strategy_followed.items() if t in time.values])
-
-        mask = (
-            pd.Series(strategy_assigned, index=eligible.index) != pd.Series(strategy_followed, index=eligible.index)
-        ).astype("boolean")
-        mask = mask | ~eligible
-        mask.reset_index(inplace=True, drop=True)
-        false = mask.loc[mask]
-        if false.empty:
-            return np.zeros(len(mask))
-        mask = (mask * 1).tolist()
-        cutoff = false.index[0] + 1
-        return mask[:cutoff] + ([None] * (len(mask) - cutoff))
-
-    def setup_xo_t_do_2(self, individual: pd.DataFrame, strategy_assigned: list):
+    def setup_xo_t_do(self, individual: pd.DataFrame, strategy_assigned: list):
        """
        Return a binary sequence with each bit representing whether the current
        index is the time point at which the individual diverted from the
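For orientation, here is a minimal standalone sketch of the censoring logic both versions of setup_xo_t_do implement, assuming strategies are (time, variable, value) triples and an "eligible" boolean column; the function name and shape are illustrative, not the commit's code:

    import pandas as pd

    def divergence_mask(individual: pd.DataFrame, strategy_assigned: list) -> list:
        """Flag the first time step where the individual diverges from the
        assigned strategy (or stops being eligible); later steps become None."""
        mask = pd.Series(False, index=individual.index)
        for t, var, val in strategy_assigned:
            # Diverged if, at step t, the observed value differs from the assigned one.
            mask |= (individual["time"] == t) & (individual[var] != val)
        mask |= ~individual["eligible"].astype(bool)
        mask = mask.reset_index(drop=True)
        diverged = mask[mask]
        if diverged.empty:
            return [0] * len(mask)  # never censored
        cutoff = diverged.index[0] + 1
        return mask.astype(int).tolist()[:cutoff] + [None] * (len(mask) - cutoff)

The returned sequence is 0 while the individual complies, 1 at the step they diverge, and None afterwards, which is what the xo_t_do column stores.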
@@ -242,95 +209,75 @@ def preprocess_data(self):
        logging.debug(" Preprocessing groups")

-        # new
-        ctrl_time, ctrl_var, ctrl_val = self.control_strategy[0]
+        ctrl_time_0, ctrl_var_0, ctrl_val_0 = self.control_strategy[0]
+        ctrl_time, ctrl_var, ctrl_val = min(
+            set(map(tuple, self.control_strategy)).difference(map(tuple, self.treatment_strategy))
+        )
        control_group = (
            living_runs.groupby("id", sort=False)
            .filter(lambda gp: len(gp.loc[(gp["time"] == ctrl_time) & (gp[ctrl_var] == ctrl_val)]) > 0)
+            .groupby("id", sort=False)
+            .filter(lambda gp: len(gp.loc[(gp["time"] == ctrl_time_0) & (gp[ctrl_var_0] == ctrl_val_0)]) > 0)
            .copy()
        )
        control_group["trtrand"] = 0
        ctrl_xo_t_do_df = control_group.groupby("id", sort=False).apply(
-            self.setup_xo_t_do_2, strategy_assigned=self.control_strategy
+            self.setup_xo_t_do, strategy_assigned=self.control_strategy
        )
        control_group["xo_t_do"] = ctrl_xo_t_do_df["xo_t_do"].values
        control_group["old_id"] = control_group["id"]
        # control_group["id"] = ctrl_xo_t_do_df["id"].values
        control_group["id"] = [f"c-{id}" for id in control_group["id"]]
        assert not control_group["id"].isnull().any(), "Null control IDs"

-        trt_time, trt_var, trt_val = self.treatment_strategy[0]
+        trt_time_0, trt_var_0, trt_val_0 = self.treatment_strategy[0]
+        trt_time, trt_var, trt_val = min(
+            set(map(tuple, self.treatment_strategy)).difference(map(tuple, self.control_strategy))
+        )
        treatment_group = (
            living_runs.groupby("id", sort=False)
            .filter(lambda gp: len(gp.loc[(gp["time"] == trt_time) & (gp[trt_var] == trt_val)]) > 0)
+            .groupby("id", sort=False)
+            .filter(lambda gp: len(gp.loc[(gp["time"] == trt_time_0) & (gp[trt_var_0] == trt_val_0)]) > 0)
            .copy()
        )
        treatment_group["trtrand"] = 1
        trt_xo_t_do_df = treatment_group.groupby("id", sort=False).apply(
-            self.setup_xo_t_do_2, strategy_assigned=self.treatment_strategy
+            self.setup_xo_t_do, strategy_assigned=self.treatment_strategy
        )
        treatment_group["xo_t_do"] = trt_xo_t_do_df["xo_t_do"].values
        treatment_group["old_id"] = treatment_group["id"]
        # treatment_group["id"] = trt_xo_t_do_df["id"].values
-        treatment_group["id"] = [f"c-{id}" for id in treatment_group["id"]]
+        treatment_group["id"] = [f"t-{id}" for id in treatment_group["id"]]
        assert not treatment_group["id"].isnull().any(), "Null treatment IDs"

+        logger.debug(
+            f"{len(control_group.groupby('id'))} control individuals, "
+            f"{len(treatment_group.groupby('id'))} treatment individuals"
+        )
+
        individuals = pd.concat([control_group, treatment_group])
        individuals = individuals.loc[
-            individuals["time"]
-            < ceil(individuals["fault_time"].iloc[0] / self.timesteps_per_observation) * self.timesteps_per_observation
-        ].copy()
-
-        individuals.sort_values(by=["old_id", "time"]).to_csv("/home/michael/tmp/vectorised_individuals.csv")
-        # end new
-
-        # individuals = []
-        #
-        # for id, individual in tqdm(living_runs.groupby("id", sort=False)):
-        #     assert sum(individual["fault_t_do"]) <= 1, (
-        #         f"Error initialising fault_t_do for individual\n"
-        #         f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
-        #         f"with fault at {individual.fault_time.iloc[0]}"
-        #     )
-        #
-        #     strategy_followed = [
-        #         [t, var, individual.loc[individual["time"] == t, var].values[0]]
-        #         for t, var, val in self.treatment_strategy
-        #         if t in individual["time"].values
-        #     ]
-        #
-        #     # Control flow:
-        #     # Individuals that start off in both arms need cloning (hence incrementing the ID within the if statement)
-        #     # Individuals that don't start off in either arm are left out
-        #     for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
-        #         if (
-        #             len(strategy_followed) > 0
-        #             and strategy_assigned[0] == strategy_followed[0]
-        #             and individual.eligible.iloc[0]
-        #         ):
-        #             individual["old_id"] = individual["id"]
-        #             individual["id"] = new_id
-        #             new_id += 1
-        #             individual["trtrand"] = inx
-        #             individual["xo_t_do"] = self.setup_xo_t_do(
-        #                 strategy_assigned, strategy_followed, individual["eligible"], individual["time"]
-        #             )
-        #             individuals.append(
-        #                 individual.loc[
-        #                     individual["time"]
-        #                     < ceil(individual["fault_time"].iloc[0] / self.timesteps_per_observation)
-        #                     * self.timesteps_per_observation
-        #                 ].copy()
-        #             )
-        # self.df = pd.concat(individuals)
-        # self.df.sort_values(by=["id", "time"]).to_csv("/home/michael/tmp/iterated_individuals.csv")
+            (
+                (
+                    individuals["time"]
+                    < ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
+                )
+                & (~individuals["xo_t_do"].isnull())
+            )
+        ]
+
+        individuals.sort_values(by=["id", "time"]).to_csv("/home/michael/tmp/vectorised_individuals.csv")
+
        if len(individuals) == 0:
            raise ValueError("No individuals followed either strategy.")
        self.df = individuals.loc[
            individuals["time"]
            < ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
        ].reset_index()
-        print(len(individuals), "individuals")
+        logger.debug(f"{len(individuals.groupby('id'))} individuals")

        if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
            raise ValueError(f"No individuals began the control strategy {self.control_strategy}")
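The doubled groupby/filter chains above keep an individual in an arm only if their run takes the arm's first action and the earliest action that distinguishes the two strategies (the min over the set difference). A sketch of that selection, assuming a long-format frame with "id" and "time" columns and strategies as (time, variable, value) triples; the helper name is hypothetical:

    import pandas as pd

    def assign_arm(living_runs: pd.DataFrame, strategy: list, other: list, label: int) -> pd.DataFrame:
        t0, var0, val0 = strategy[0]
        # Earliest step unique to this strategy; assumes the two strategies differ somewhere.
        t_u, var_u, val_u = min(set(map(tuple, strategy)).difference(map(tuple, other)))
        group = (
            living_runs.groupby("id", sort=False)
            # Keep runs that take the distinguishing action...
            .filter(lambda gp: ((gp["time"] == t_u) & (gp[var_u] == val_u)).any())
            .groupby("id", sort=False)
            # ...and that also start on the arm's first action.
            .filter(lambda gp: ((gp["time"] == t0) & (gp[var0] == val0)).any())
            .copy()
        )
        group["trtrand"] = label  # 0 = control arm, 1 = treatment arm
        return group

Under these assumptions, the control branch corresponds to assign_arm(living_runs, self.control_strategy, self.treatment_strategy, 0), with the treatment branch mirrored.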
@@ -348,13 +295,15 @@ def estimate_hazard_ratio(self):
        preprocessed_data = self.df.copy()

        # Use logistic regression to predict switching given baseline covariates
-        print("Use logistic regression to predict switching given baseline covariates")
+        logger.debug("Use logistic regression to predict switching given baseline covariates")
        fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()

        preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)

        # Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
-        print("Use logistic regression to predict switching given baseline and time-updated covariates (model S12)")
+        logger.debug(
+            "Use logistic regression to predict switching given baseline and time-updated covariates (model S12)"
+        )
        fit_bltd_switch = smf.logit(
            self.fit_bltd_switch_formula,
            data=self.df,
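Both switching models in this hunk follow the same statsmodels pattern: fit a logistic regression from a formula string, then predict a per-row switching probability. A toy illustration with made-up column names and data (the real formulas come from self.fit_bl_switch_formula and self.fit_bltd_switch_formula):

    import pandas as pd
    import statsmodels.formula.api as smf

    df = pd.DataFrame({
        "xo_t_do": [0, 0, 1, 0, 1, 0, 1, 1],            # switched from assigned strategy?
        "age":     [50, 61, 72, 45, 80, 67, 55, 70],    # illustrative baseline covariate
    })
    fit = smf.logit("xo_t_do ~ age", data=df).fit(disp=0)
    df["pxo"] = fit.predict(df)  # P(switch) at each row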
@@ -368,7 +317,7 @@ def estimate_hazard_ratio(self):
        # IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
        # Estimate the probabilities of remaining 'un-switched' and hence the weights
-        print("Estimate the probabilities of remaining 'un-switched' and hence the weights")
+        logger.debug("Estimate the probabilities of remaining 'un-switched' and hence the weights")

        preprocessed_data["num"] = 1 - preprocessed_data["pxo1"]
        preprocessed_data["denom"] = 1 - preprocessed_data["pxo2"]
@@ -397,10 +346,12 @@ def estimate_hazard_ratio(self):
            f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout'], ['id', 'time', 'fault_time', 'tin', 'tout']]}"
        )

+        preprocessed_data.to_csv("/home/michael/tmp/preprocessed_data.csv")
+
        # IPCW step 4: Use these weights in a weighted analysis of the outcome model
        # Estimate the KM graph and IPCW hazard ratio using Cox regression.
-        print("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
-        cox_ph = CoxPHFitter()
+        logger.debug("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
+        cox_ph = CoxPHFitter(penalizer=0.2, alpha=self.alpha)
        cox_ph.fit(
            df=preprocessed_data,
            duration_col="tout",
@@ -411,7 +362,6 @@ def estimate_hazard_ratio(self):
            formula="trtrand",
            entry_col="tin",
        )
-        print("Estimated")

        ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]