Skip to content

Commit 0d76e92

Browse files
committed
Fixed estimation bug
1 parent b321170 commit 0d76e92

File tree

2 files changed

+29
-34
lines changed

2 files changed

+29
-34
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from itertools import combinations
66
from copy import deepcopy
77
import pandas as pd
8+
from numpy.linalg import LinAlgError
9+
from lifelines.exceptions import ConvergenceError
810

911
from causal_testing.testing.causal_test_suite import CausalTestSuite
1012
from causal_testing.data_collection.data_collector import DataCollector
@@ -101,7 +103,12 @@ def measure_adequacy(self):
101103
estimator.df = estimator.df[estimator.df[self.group_by].isin(ids)]
102104
else:
103105
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
104-
results.append(self.test_case.execute_test(estimator, self.data_collector))
106+
try:
107+
results.append(self.test_case.execute_test(estimator, self.data_collector))
108+
except LinAlgError:
109+
continue
110+
except ConvergenceError:
111+
continue
105112
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
106113
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
107114

causal_testing/testing/estimators.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -760,73 +760,60 @@ def preprocess_data(self, df):
760760
individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
761761
if len(individuals) == 0:
762762
raise ValueError("No individuals followed either strategy.")
763-
return pd.concat(individuals)
764763

765-
def estimate_hazard_ratio(self):
766-
"""
767-
Estimate the hazard ratio.
768-
"""
764+
novCEA = pd.concat(individuals)
769765

770-
if self.df["fault_t_do"].sum() == 0:
766+
if novCEA["fault_t_do"].sum() == 0:
771767
raise ValueError("No recorded faults")
772768

773-
logging.debug(f" {int(self.df['fault_t_do'].sum())}/{len(self.df.groupby('id'))} faulty runs observed")
774-
775769
# Use logistic regression to predict switching given baseline covariates
776-
logging.debug(f" predict switching given baseline covariates: {self.fitBLswitch_formula}")
777-
fitBLswitch = smf.logit(self.fitBLswitch_formula, data=self.df).fit()
770+
fitBLswitch = smf.logit(self.fitBLswitch_formula, data=novCEA).fit()
778771

779-
# Estimate the probability of switching for each patient-observation included in the regression.
780-
novCEA = pd.DataFrame()
781-
novCEA["pxo1"] = fitBLswitch.predict(self.df)
772+
novCEA["pxo1"] = fitBLswitch.predict(novCEA)
782773

783774
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
784-
logging.debug(f" predict switching given baseline and time-updated covariates: {self.fitBLTDswitch_formula}")
785-
786775
fitBLTDswitch = smf.logit(
787776
self.fitBLTDswitch_formula,
788-
data=self.df,
777+
data=novCEA,
789778
).fit()
790779

791-
# Estimate the probability of switching for each patient-observation included in the regression.
792-
novCEA["pxo2"] = fitBLTDswitch.predict(self.df)
780+
novCEA["pxo2"] = fitBLTDswitch.predict(novCEA)
793781

794782
# IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
795783
# Estimate the probabilities of remaining ‘un-switched’ and hence the weights
796784

797785
novCEA["num"] = 1 - novCEA["pxo1"]
798786
novCEA["denom"] = 1 - novCEA["pxo2"]
799-
prod = (
800-
pd.concat([self.df, novCEA], axis=1).sort_values(["id", "time"]).groupby("id")[["num", "denom"]].cumprod()
801-
)
802-
novCEA["num"] = prod["num"]
803-
novCEA["denom"] = prod["denom"]
787+
novCEA[["num", "denom"]] = novCEA.sort_values(["id", "time"]).groupby("id")[["num", "denom"]].cumprod()
804788

805789
assert not novCEA["num"].isnull().any(), f"{len(novCEA['num'].isnull())} null numerator values"
806790
assert not novCEA["denom"].isnull().any(), f"{len(novCEA['denom'].isnull())} null denom values"
807791

808792
novCEA["weight"] = 1 / novCEA["denom"]
809793
novCEA["sweight"] = novCEA["num"] / novCEA["denom"]
810794

811-
novCEA_KM = novCEA.loc[self.df["xo_t_do"] == 0].copy()
812-
novCEA_KM["tin"] = self.df["time"]
795+
novCEA_KM = novCEA.loc[novCEA["xo_t_do"] == 0].copy()
796+
novCEA_KM["tin"] = novCEA_KM["time"]
813797
novCEA_KM["tout"] = pd.concat(
814-
[(self.df["time"] + self.timesteps_per_intervention), self.df["fault_time"]], axis=1
798+
[(novCEA_KM["time"] + self.timesteps_per_intervention), novCEA_KM["fault_time"]], axis=1
815799
).min(axis=1)
816800

817801
assert (
818802
novCEA_KM["tin"] <= novCEA_KM["tout"]
819803
).all(), f"Left before joining\n{novCEA_KM.loc[novCEA_KM['tin'] >= novCEA_KM['tout']]}"
820804

821-
novCEA_KM.dropna(axis=1, inplace=True)
822-
novCEA_KM.replace([float("inf")], 100, inplace=True)
805+
return novCEA_KM
806+
807+
def estimate_hazard_ratio(self):
808+
"""
809+
Estimate the hazard ratio.
810+
"""
823811

824812
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
825813
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
826-
cox_ph = CoxPHFitter(alpha=self.alpha)
827-
814+
cox_ph = CoxPHFitter()
828815
cox_ph.fit(
829-
df=pd.concat([self.df, novCEA_KM], axis=1),
816+
df=self.df,
830817
duration_col="tout",
831818
event_col="fault_t_do",
832819
weights_col="weight",
@@ -835,6 +822,7 @@ def estimate_hazard_ratio(self):
835822
formula="trtrand",
836823
entry_col="tin",
837824
)
838-
ci_low, ci_high = sorted(np.exp(cox_ph.confidence_intervals_).T["trtrand"].tolist())
839825

840-
return (cox_ph.hazard_ratios_["trtrand"], ([ci_low], [ci_high]))
826+
ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]
827+
828+
return (cox_ph.hazard_ratios_, (ci_low, ci_high))

0 commit comments

Comments
 (0)