@@ -608,17 +608,18 @@ class IPCWEstimator(Estimator):
608
608
for sequences of treatments over time-varying data.
609
609
"""
610
610
611
+ # pylint: disable=too-many-arguments
612
+ # pylint: disable=too-many-instance-attributes
611
613
def __init__ (
612
614
self ,
613
615
df : pd .DataFrame ,
614
616
timesteps_per_intervention : int ,
615
617
control_strategy : TreatmentSequence ,
616
618
treatment_strategy : TreatmentSequence ,
617
619
outcome : str ,
618
- min : float ,
619
- max : float ,
620
- fitBLswitch_formula : str ,
621
- fitBLTDswitch_formula : str ,
620
+ fault_column : str ,
621
+ fit_bl_switch_formula : str ,
622
+ fit_bltd_switch_formula : str ,
622
623
eligibility = None ,
623
624
alpha : float = 0.05 ,
624
625
query : str = "" ,
@@ -638,13 +639,12 @@ def __init__(
638
639
self .control_strategy = control_strategy
639
640
self .treatment_strategy = treatment_strategy
640
641
self .outcome = outcome
641
- self .min = min
642
- self .max = max
642
+ self .fault_column = fault_column
643
643
self .timesteps_per_intervention = timesteps_per_intervention
644
- self .fitBLswitch_formula = fitBLswitch_formula
645
- self .fitBLTDswitch_formula = fitBLTDswitch_formula
644
+ self .fit_bl_switch_formula = fit_bl_switch_formula
645
+ self .fit_bltd_switch_formula = fit_bltd_switch_formula
646
646
self .eligibility = eligibility
647
- self .df = self . preprocess_data ( df )
647
+ self .df = df
648
648
649
649
def add_modelling_assumptions (self ):
650
650
self .modelling_assumptions .append ("The variables in the data vary over time." )
@@ -667,21 +667,20 @@ def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligib
667
667
).astype ("boolean" )
668
668
mask = mask | ~ eligible
669
669
mask .reset_index (inplace = True , drop = True )
670
- false = mask .loc [mask == True ]
670
+ false = mask .loc [mask ]
671
671
if false .empty :
672
672
return np .zeros (len (mask ))
673
- else :
674
- mask = (mask * 1 ).tolist ()
675
- cutoff = false .index [0 ] + 1
676
- return mask [:cutoff ] + ([None ] * (len (mask ) - cutoff ))
673
+ mask = (mask * 1 ).tolist ()
674
+ cutoff = false .index [0 ] + 1
675
+ return mask [:cutoff ] + ([None ] * (len (mask ) - cutoff ))
677
676
678
677
def setup_fault_t_do (self , individual : pd .DataFrame ):
679
678
"""
680
679
Return a binary sequence with each bit representing whether the current
681
680
index is the time point at which the event of interest (i.e. a fault)
682
681
occurred.
683
682
"""
684
- fault = individual [individual ["within_safe_range" ] == False ]
683
+ fault = individual [~ individual [self . fault_column ] ]
685
684
fault_t_do = pd .Series (np .zeros (len (individual )), index = individual .index )
686
685
687
686
if not fault .empty :
@@ -702,39 +701,43 @@ def setup_fault_time(self, individual, perturbation=-0.001):
702
701
"""
703
702
Return the time at which the event of interest (i.e. a fault) occurred.
704
703
"""
705
- fault = individual [individual ["within_safe_range" ] == False ]
704
+ fault = individual [~ individual [self . fault_column ] ]
706
705
fault_time = (
707
706
individual ["time" ].loc [fault .index [0 ]]
708
707
if not fault .empty
709
708
else (individual ["time" ].max () + self .timesteps_per_intervention )
710
709
)
711
710
return pd .DataFrame ({"fault_time" : np .repeat (fault_time + perturbation , len (individual ))})
712
711
713
- def preprocess_data (self , df ):
712
+ def preprocess_data (self ):
714
713
"""
715
714
Set up the treatment-specific columns in the data that are needed to estimate the hazard ratio.
716
715
"""
717
- df ["trtrand" ] = None # treatment/control arm
718
- df ["xo_t_do" ] = None # did the individual deviate from the treatment of interest here?
719
- df ["eligible" ] = df .eval (self .eligibility ) if self .eligibility is not None else True
716
+ self . df ["trtrand" ] = None # treatment/control arm
717
+ self . df ["xo_t_do" ] = None # did the individual deviate from the treatment of interest here?
718
+ self . df ["eligible" ] = self . df .eval (self .eligibility ) if self .eligibility is not None else True
720
719
721
720
# when did a fault occur?
722
- df ["within_safe_range" ] = df [self .outcome ].between (self .min , self .max )
723
- df ["fault_time" ] = df .groupby ("id" )[["within_safe_range" , "time" ]].apply (self .setup_fault_time ).values
724
- df ["fault_t_do" ] = df .groupby ("id" )[["id" , "time" , "within_safe_range" ]].apply (self .setup_fault_t_do ).values
725
- assert not pd .isnull (df ["fault_time" ]).any ()
721
+ self .df ["fault_time" ] = self .df .groupby ("id" )[[self .fault_column , "time" ]].apply (self .setup_fault_time ).values
722
+ self .df ["fault_t_do" ] = (
723
+ self .df .groupby ("id" )[["id" , "time" , self .fault_column ]].apply (self .setup_fault_t_do ).values
724
+ )
725
+ assert not pd .isnull (self .df ["fault_time" ]).any ()
726
726
727
- living_runs = df .query ("fault_time > 0" ).loc [
728
- (df ["time" ] % self .timesteps_per_intervention == 0 ) & (df ["time" ] <= self .control_strategy .total_time ())
727
+ living_runs = self .df .query ("fault_time > 0" ).loc [
728
+ (self .df ["time" ] % self .timesteps_per_intervention == 0 )
729
+ & (self .df ["time" ] <= self .control_strategy .total_time ())
729
730
]
730
731
731
732
individuals = []
732
733
new_id = 0
733
734
logging .debug (" Preprocessing groups" )
734
- for id , individual in living_runs .groupby ("id" ):
735
- assert (
736
- sum (individual ["fault_t_do" ]) <= 1
737
- ), f"Error initialising fault_t_do for individual\n { individual [['id' , 'time' , 'fault_time' , 'fault_t_do' ]]} \n with fault at { individual .fault_time .iloc [0 ]} "
735
+ for _ , individual in living_runs .groupby ("id" ):
736
+ assert sum (individual ["fault_t_do" ]) <= 1 , (
737
+ f"Error initialising fault_t_do for individual\n "
738
+ f"{ individual [['id' , 'time' , 'fault_time' , 'fault_t_do' ]]} \n "
739
+ "with fault at {individual.fault_time.iloc[0]}"
740
+ )
738
741
739
742
strategy_followed = [
740
743
Capability (
@@ -761,59 +764,67 @@ def preprocess_data(self, df):
761
764
if len (individuals ) == 0 :
762
765
raise ValueError ("No individuals followed either strategy." )
763
766
764
- novCEA = pd .concat (individuals )
767
+ return pd .concat (individuals )
765
768
766
- if novCEA ["fault_t_do" ].sum () == 0 :
769
+ def estimate_hazard_ratio (self ):
770
+ """
771
+ Estimate the hazard ratio.
772
+ """
773
+
774
+ preprocessed_data = self .preprocess_data ()
775
+
776
+ if preprocessed_data ["fault_t_do" ].sum () == 0 :
767
777
raise ValueError ("No recorded faults" )
768
778
769
779
# Use logistic regression to predict switching given baseline covariates
770
- fitBLswitch = smf .logit (self .fitBLswitch_formula , data = novCEA ).fit ()
780
+ fit_bl_switch = smf .logit (self .fit_bl_switch_formula , data = preprocessed_data ).fit ()
771
781
772
- novCEA ["pxo1" ] = fitBLswitch .predict (novCEA )
782
+ preprocessed_data ["pxo1" ] = fit_bl_switch .predict (preprocessed_data )
773
783
774
784
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
775
- fitBLTDswitch = smf .logit (
776
- self .fitBLTDswitch_formula ,
777
- data = novCEA ,
785
+ fit_bltd_switch = smf .logit (
786
+ self .fit_bltd_switch_formula ,
787
+ data = preprocessed_data ,
778
788
).fit ()
779
789
780
- novCEA ["pxo2" ] = fitBLTDswitch .predict (novCEA )
790
+ preprocessed_data ["pxo2" ] = fit_bltd_switch .predict (preprocessed_data )
781
791
782
792
# IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
783
793
# Estimate the probabilities of remaining ‘un-switched’ and hence the weights
784
794
785
- novCEA ["num" ] = 1 - novCEA ["pxo1" ]
786
- novCEA ["denom" ] = 1 - novCEA ["pxo2" ]
787
- novCEA [["num" , "denom" ]] = novCEA .sort_values (["id" , "time" ]).groupby ("id" )[["num" , "denom" ]].cumprod ()
795
+ preprocessed_data ["num" ] = 1 - preprocessed_data ["pxo1" ]
796
+ preprocessed_data ["denom" ] = 1 - preprocessed_data ["pxo2" ]
797
+ preprocessed_data [["num" , "denom" ]] = (
798
+ preprocessed_data .sort_values (["id" , "time" ]).groupby ("id" )[["num" , "denom" ]].cumprod ()
799
+ )
788
800
789
- assert not novCEA ["num" ].isnull ().any (), f"{ len (novCEA ['num' ].isnull ())} null numerator values"
790
- assert not novCEA ["denom" ].isnull ().any (), f"{ len (novCEA ['denom' ].isnull ())} null denom values"
801
+ assert (
802
+ not preprocessed_data ["num" ].isnull ().any ()
803
+ ), f"{ len (preprocessed_data ['num' ].isnull ())} null numerator values"
804
+ assert (
805
+ not preprocessed_data ["denom" ].isnull ().any ()
806
+ ), f"{ len (preprocessed_data ['denom' ].isnull ())} null denom values"
791
807
792
- novCEA ["weight" ] = 1 / novCEA ["denom" ]
793
- novCEA ["sweight" ] = novCEA ["num" ] / novCEA ["denom" ]
808
+ preprocessed_data ["weight" ] = 1 / preprocessed_data ["denom" ]
809
+ preprocessed_data ["sweight" ] = preprocessed_data ["num" ] / preprocessed_data ["denom" ]
794
810
795
- novCEA_KM = novCEA .loc [novCEA ["xo_t_do" ] == 0 ].copy ()
796
- novCEA_KM ["tin" ] = novCEA_KM ["time" ]
797
- novCEA_KM ["tout" ] = pd .concat (
798
- [(novCEA_KM ["time" ] + self .timesteps_per_intervention ), novCEA_KM ["fault_time" ]], axis = 1
811
+ preprocessed_data_km = preprocessed_data .loc [preprocessed_data ["xo_t_do" ] == 0 ].copy ()
812
+ preprocessed_data_km ["tin" ] = preprocessed_data_km ["time" ]
813
+ preprocessed_data_km ["tout" ] = pd .concat (
814
+ [(preprocessed_data_km ["time" ] + self .timesteps_per_intervention ), preprocessed_data_km ["fault_time" ]],
815
+ axis = 1 ,
799
816
).min (axis = 1 )
800
817
801
- assert (
802
- novCEA_KM ["tin" ] <= novCEA_KM ["tout" ]
803
- ).all (), f"Left before joining\n { novCEA_KM .loc [novCEA_KM ['tin' ] >= novCEA_KM ['tout' ]]} "
804
-
805
- return novCEA_KM
806
-
807
- def estimate_hazard_ratio (self ):
808
- """
809
- Estimate the hazard ratio.
810
- """
818
+ assert (preprocessed_data_km ["tin" ] <= preprocessed_data_km ["tout" ]).all (), (
819
+ f"Left before joining\n "
820
+ f"{ preprocessed_data_km .loc [preprocessed_data_km ['tin' ] >= preprocessed_data_km ['tout' ]]} "
821
+ )
811
822
812
823
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
813
824
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
814
825
cox_ph = CoxPHFitter ()
815
826
cox_ph .fit (
816
- df = self . df ,
827
+ df = preprocessed_data_km ,
817
828
duration_col = "tout" ,
818
829
event_col = "fault_t_do" ,
819
830
weights_col = "weight" ,
0 commit comments