Skip to content

Commit 6ceef1a

Browse files
committed
Some more linting
1 parent 392387c commit 6ceef1a

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def measure_adequacy(self):
107107
continue
108108
except ConvergenceError:
109109
continue
110+
except ValueError:
111+
continue
110112
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
111113
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
112114

causal_testing/testing/estimators.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,6 @@ def __init__(
622622
fit_bltd_switch_formula: str,
623623
eligibility=None,
624624
alpha: float = 0.05,
625-
query: str = "",
626625
):
627626
super().__init__(
628627
[c.variable for c in treatment_strategy.capabilities],
@@ -633,7 +632,7 @@ def __init__(
633632
df,
634633
None,
635634
alpha=alpha,
636-
query=query,
635+
query="",
637636
)
638637
self.timesteps_per_intervention = timesteps_per_intervention
639638
self.control_strategy = control_strategy
@@ -645,6 +644,7 @@ def __init__(
645644
self.fit_bltd_switch_formula = fit_bltd_switch_formula
646645
self.eligibility = eligibility
647646
self.df = df
647+
self.preprocess_data()
648648

649649
def add_modelling_assumptions(self):
650650
self.modelling_assumptions.append("The variables in the data vary over time.")
@@ -764,27 +764,27 @@ def preprocess_data(self):
764764
if len(individuals) == 0:
765765
raise ValueError("No individuals followed either strategy.")
766766

767-
return pd.concat(individuals)
767+
self.df = pd.concat(individuals)
768768

769769
def estimate_hazard_ratio(self):
770770
"""
771771
Estimate the hazard ratio.
772772
"""
773773

774-
preprocessed_data = self.preprocess_data()
775-
776-
if preprocessed_data["fault_t_do"].sum() == 0:
774+
if self.df["fault_t_do"].sum() == 0:
777775
raise ValueError("No recorded faults")
778776

777+
preprocessed_data = self.df.loc[self.df["xo_t_do"] == 0].copy()
778+
779779
# Use logistic regression to predict switching given baseline covariates
780-
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=preprocessed_data).fit()
780+
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()
781781

782782
preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)
783783

784784
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
785785
fit_bltd_switch = smf.logit(
786786
self.fit_bltd_switch_formula,
787-
data=preprocessed_data,
787+
data=self.df,
788788
).fit()
789789

790790
preprocessed_data["pxo2"] = fit_bltd_switch.predict(preprocessed_data)
@@ -808,23 +808,21 @@ def estimate_hazard_ratio(self):
808808
preprocessed_data["weight"] = 1 / preprocessed_data["denom"]
809809
preprocessed_data["sweight"] = preprocessed_data["num"] / preprocessed_data["denom"]
810810

811-
preprocessed_data_km = preprocessed_data.loc[preprocessed_data["xo_t_do"] == 0].copy()
812-
preprocessed_data_km["tin"] = preprocessed_data_km["time"]
813-
preprocessed_data_km["tout"] = pd.concat(
814-
[(preprocessed_data_km["time"] + self.timesteps_per_intervention), preprocessed_data_km["fault_time"]],
811+
preprocessed_data["tin"] = preprocessed_data["time"]
812+
preprocessed_data["tout"] = pd.concat(
813+
[(preprocessed_data["time"] + self.timesteps_per_intervention), preprocessed_data["fault_time"]],
815814
axis=1,
816815
).min(axis=1)
817816

818-
assert (preprocessed_data_km["tin"] <= preprocessed_data_km["tout"]).all(), (
819-
f"Left before joining\n"
820-
f"{preprocessed_data_km.loc[preprocessed_data_km['tin'] >= preprocessed_data_km['tout']]}"
817+
assert (preprocessed_data["tin"] <= preprocessed_data["tout"]).all(), (
818+
f"Left before joining\n" f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout']]}"
821819
)
822820

823821
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
824822
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
825823
cox_ph = CoxPHFitter()
826824
cox_ph.fit(
827-
df=preprocessed_data_km,
825+
df=preprocessed_data,
828826
duration_col="tout",
829827
event_col="fault_t_do",
830828
weights_col="weight",

0 commit comments

Comments
 (0)