Skip to content

Commit 392387c

Browse files
committed
pytest and linting
1 parent e2aef75 commit 392387c

File tree

6 files changed

+172
-70
lines changed

6 files changed

+172
-70
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def _execute_test_case(
276276
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
277277

278278
if "coverage" in test and test["coverage"]:
279-
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, self.data_collector)
279+
adequacy_metric = DataAdequacy(causal_test_case, estimation_model)
280280
adequacy_metric.measure_adequacy()
281281
causal_test_result.adequacy = adequacy_metric
282282

causal_testing/specification/capabilities.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
This module contains the Capability and TreatmentSequence classes to implement
3+
treatment sequences that operate over time.
4+
"""
15
from causal_testing.specification.variable import Variable
26

37

@@ -14,7 +18,7 @@ def __init__(self, variable: Variable, value: any, start_time: int, end_time: fl
1418

1519
def __eq__(self, other):
1620
return (
17-
type(other) == type(self)
21+
isinstance(other, type(self))
1822
and self.variable == other.variable
1923
and self.value == other.value
2024
and self.start_time == other.start_time
@@ -44,7 +48,7 @@ def __init__(self, timesteps_per_intervention, capabilities):
4448
)
4549
]
4650
# This is a bodge so that causal test adequacy works
47-
self.name = tuple([c.variable for c in self.capabilities])
51+
self.name = tuple(c.variable for c in self.capabilities)
4852

4953
def set_value(self, index: int, value: float):
5054
"""

causal_testing/testing/causal_test_adequacy.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from lifelines.exceptions import ConvergenceError
1010

1111
from causal_testing.testing.causal_test_suite import CausalTestSuite
12-
from causal_testing.data_collection.data_collector import DataCollector
1312
from causal_testing.specification.causal_dag import CausalDAG
1413
from causal_testing.testing.estimators import Estimator
1514
from causal_testing.testing.causal_test_case import CausalTestCase
@@ -72,17 +71,16 @@ class DataAdequacy:
7271
- Zero kurtosis is optimal.
7372
"""
7473

74+
# pylint: disable=too-many-instance-attributes
7575
def __init__(
7676
self,
7777
test_case: CausalTestCase,
7878
estimator: Estimator,
79-
data_collector: DataCollector,
8079
bootstrap_size: int = 100,
8180
group_by=None,
8281
):
8382
self.test_case = test_case
8483
self.estimator = estimator
85-
self.data_collector = data_collector
8684
self.kurtosis = None
8785
self.outcomes = None
8886
self.successful = None
@@ -104,7 +102,7 @@ def measure_adequacy(self):
104102
else:
105103
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
106104
try:
107-
results.append(self.test_case.execute_test(estimator, self.data_collector))
105+
results.append(self.test_case.execute_test(estimator, None))
108106
except LinAlgError:
109107
continue
110108
except ConvergenceError:
@@ -129,7 +127,7 @@ def convert_to_df(field):
129127
effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True)
130128
self.kurtosis = effect_estimate.kurtosis()
131129
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
132-
self.successful = sum([x is not None for x in outcomes])
130+
self.successful = sum(x is not None for x in outcomes)
133131

134132
def to_dict(self):
135133
"Returns the adequacy object as a dictionary."

causal_testing/testing/estimators.py

Lines changed: 71 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -608,17 +608,18 @@ class IPCWEstimator(Estimator):
608608
for sequences of treatments over time-varying data.
609609
"""
610610

611+
# pylint: disable=too-many-arguments
612+
# pylint: disable=too-many-instance-attributes
611613
def __init__(
612614
self,
613615
df: pd.DataFrame,
614616
timesteps_per_intervention: int,
615617
control_strategy: TreatmentSequence,
616618
treatment_strategy: TreatmentSequence,
617619
outcome: str,
618-
min: float,
619-
max: float,
620-
fitBLswitch_formula: str,
621-
fitBLTDswitch_formula: str,
620+
fault_column: str,
621+
fit_bl_switch_formula: str,
622+
fit_bltd_switch_formula: str,
622623
eligibility=None,
623624
alpha: float = 0.05,
624625
query: str = "",
@@ -638,13 +639,12 @@ def __init__(
638639
self.control_strategy = control_strategy
639640
self.treatment_strategy = treatment_strategy
640641
self.outcome = outcome
641-
self.min = min
642-
self.max = max
642+
self.fault_column = fault_column
643643
self.timesteps_per_intervention = timesteps_per_intervention
644-
self.fitBLswitch_formula = fitBLswitch_formula
645-
self.fitBLTDswitch_formula = fitBLTDswitch_formula
644+
self.fit_bl_switch_formula = fit_bl_switch_formula
645+
self.fit_bltd_switch_formula = fit_bltd_switch_formula
646646
self.eligibility = eligibility
647-
self.df = self.preprocess_data(df)
647+
self.df = df
648648

649649
def add_modelling_assumptions(self):
650650
self.modelling_assumptions.append("The variables in the data vary over time.")
@@ -667,21 +667,20 @@ def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligib
667667
).astype("boolean")
668668
mask = mask | ~eligible
669669
mask.reset_index(inplace=True, drop=True)
670-
false = mask.loc[mask == True]
670+
false = mask.loc[mask]
671671
if false.empty:
672672
return np.zeros(len(mask))
673-
else:
674-
mask = (mask * 1).tolist()
675-
cutoff = false.index[0] + 1
676-
return mask[:cutoff] + ([None] * (len(mask) - cutoff))
673+
mask = (mask * 1).tolist()
674+
cutoff = false.index[0] + 1
675+
return mask[:cutoff] + ([None] * (len(mask) - cutoff))
677676

678677
def setup_fault_t_do(self, individual: pd.DataFrame):
679678
"""
680679
Return a binary sequence with each bit representing whether the current
681680
index is the time point at which the event of interest (i.e. a fault)
682681
occurred.
683682
"""
684-
fault = individual[individual["within_safe_range"] == False]
683+
fault = individual[~individual[self.fault_column]]
685684
fault_t_do = pd.Series(np.zeros(len(individual)), index=individual.index)
686685

687686
if not fault.empty:
@@ -702,39 +701,43 @@ def setup_fault_time(self, individual, perturbation=-0.001):
702701
"""
703702
Return the time at which the event of interest (i.e. a fault) occurred.
704703
"""
705-
fault = individual[individual["within_safe_range"] == False]
704+
fault = individual[~individual[self.fault_column]]
706705
fault_time = (
707706
individual["time"].loc[fault.index[0]]
708707
if not fault.empty
709708
else (individual["time"].max() + self.timesteps_per_intervention)
710709
)
711710
return pd.DataFrame({"fault_time": np.repeat(fault_time + perturbation, len(individual))})
712711

713-
def preprocess_data(self, df):
712+
def preprocess_data(self):
714713
"""
715714
Set up the treatment-specific columns in the data that are needed to estimate the hazard ratio.
716715
"""
717-
df["trtrand"] = None # treatment/control arm
718-
df["xo_t_do"] = None # did the individual deviate from the treatment of interest here?
719-
df["eligible"] = df.eval(self.eligibility) if self.eligibility is not None else True
716+
self.df["trtrand"] = None # treatment/control arm
717+
self.df["xo_t_do"] = None # did the individual deviate from the treatment of interest here?
718+
self.df["eligible"] = self.df.eval(self.eligibility) if self.eligibility is not None else True
720719

721720
# when did a fault occur?
722-
df["within_safe_range"] = df[self.outcome].between(self.min, self.max)
723-
df["fault_time"] = df.groupby("id")[["within_safe_range", "time"]].apply(self.setup_fault_time).values
724-
df["fault_t_do"] = df.groupby("id")[["id", "time", "within_safe_range"]].apply(self.setup_fault_t_do).values
725-
assert not pd.isnull(df["fault_time"]).any()
721+
self.df["fault_time"] = self.df.groupby("id")[[self.fault_column, "time"]].apply(self.setup_fault_time).values
722+
self.df["fault_t_do"] = (
723+
self.df.groupby("id")[["id", "time", self.fault_column]].apply(self.setup_fault_t_do).values
724+
)
725+
assert not pd.isnull(self.df["fault_time"]).any()
726726

727-
living_runs = df.query("fault_time > 0").loc[
728-
(df["time"] % self.timesteps_per_intervention == 0) & (df["time"] <= self.control_strategy.total_time())
727+
living_runs = self.df.query("fault_time > 0").loc[
728+
(self.df["time"] % self.timesteps_per_intervention == 0)
729+
& (self.df["time"] <= self.control_strategy.total_time())
729730
]
730731

731732
individuals = []
732733
new_id = 0
733734
logging.debug(" Preprocessing groups")
734-
for id, individual in living_runs.groupby("id"):
735-
assert (
736-
sum(individual["fault_t_do"]) <= 1
737-
), f"Error initialising fault_t_do for individual\n{individual[['id', 'time', 'fault_time', 'fault_t_do']]}\nwith fault at {individual.fault_time.iloc[0]}"
735+
for _, individual in living_runs.groupby("id"):
736+
assert sum(individual["fault_t_do"]) <= 1, (
737+
f"Error initialising fault_t_do for individual\n"
738+
f"{individual[['id', 'time', 'fault_time', 'fault_t_do']]}\n"
739+
"with fault at {individual.fault_time.iloc[0]}"
740+
)
738741

739742
strategy_followed = [
740743
Capability(
@@ -761,59 +764,67 @@ def preprocess_data(self, df):
761764
if len(individuals) == 0:
762765
raise ValueError("No individuals followed either strategy.")
763766

764-
novCEA = pd.concat(individuals)
767+
return pd.concat(individuals)
765768

766-
if novCEA["fault_t_do"].sum() == 0:
769+
def estimate_hazard_ratio(self):
770+
"""
771+
Estimate the hazard ratio.
772+
"""
773+
774+
preprocessed_data = self.preprocess_data()
775+
776+
if preprocessed_data["fault_t_do"].sum() == 0:
767777
raise ValueError("No recorded faults")
768778

769779
# Use logistic regression to predict switching given baseline covariates
770-
fitBLswitch = smf.logit(self.fitBLswitch_formula, data=novCEA).fit()
780+
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=preprocessed_data).fit()
771781

772-
novCEA["pxo1"] = fitBLswitch.predict(novCEA)
782+
preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)
773783

774784
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
775-
fitBLTDswitch = smf.logit(
776-
self.fitBLTDswitch_formula,
777-
data=novCEA,
785+
fit_bltd_switch = smf.logit(
786+
self.fit_bltd_switch_formula,
787+
data=preprocessed_data,
778788
).fit()
779789

780-
novCEA["pxo2"] = fitBLTDswitch.predict(novCEA)
790+
preprocessed_data["pxo2"] = fit_bltd_switch.predict(preprocessed_data)
781791

782792
# IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
783793
# Estimate the probabilities of remaining ‘un-switched’ and hence the weights
784794

785-
novCEA["num"] = 1 - novCEA["pxo1"]
786-
novCEA["denom"] = 1 - novCEA["pxo2"]
787-
novCEA[["num", "denom"]] = novCEA.sort_values(["id", "time"]).groupby("id")[["num", "denom"]].cumprod()
795+
preprocessed_data["num"] = 1 - preprocessed_data["pxo1"]
796+
preprocessed_data["denom"] = 1 - preprocessed_data["pxo2"]
797+
preprocessed_data[["num", "denom"]] = (
798+
preprocessed_data.sort_values(["id", "time"]).groupby("id")[["num", "denom"]].cumprod()
799+
)
788800

789-
assert not novCEA["num"].isnull().any(), f"{len(novCEA['num'].isnull())} null numerator values"
790-
assert not novCEA["denom"].isnull().any(), f"{len(novCEA['denom'].isnull())} null denom values"
801+
assert (
802+
not preprocessed_data["num"].isnull().any()
803+
), f"{len(preprocessed_data['num'].isnull())} null numerator values"
804+
assert (
805+
not preprocessed_data["denom"].isnull().any()
806+
), f"{len(preprocessed_data['denom'].isnull())} null denom values"
791807

792-
novCEA["weight"] = 1 / novCEA["denom"]
793-
novCEA["sweight"] = novCEA["num"] / novCEA["denom"]
808+
preprocessed_data["weight"] = 1 / preprocessed_data["denom"]
809+
preprocessed_data["sweight"] = preprocessed_data["num"] / preprocessed_data["denom"]
794810

795-
novCEA_KM = novCEA.loc[novCEA["xo_t_do"] == 0].copy()
796-
novCEA_KM["tin"] = novCEA_KM["time"]
797-
novCEA_KM["tout"] = pd.concat(
798-
[(novCEA_KM["time"] + self.timesteps_per_intervention), novCEA_KM["fault_time"]], axis=1
811+
preprocessed_data_km = preprocessed_data.loc[preprocessed_data["xo_t_do"] == 0].copy()
812+
preprocessed_data_km["tin"] = preprocessed_data_km["time"]
813+
preprocessed_data_km["tout"] = pd.concat(
814+
[(preprocessed_data_km["time"] + self.timesteps_per_intervention), preprocessed_data_km["fault_time"]],
815+
axis=1,
799816
).min(axis=1)
800817

801-
assert (
802-
novCEA_KM["tin"] <= novCEA_KM["tout"]
803-
).all(), f"Left before joining\n{novCEA_KM.loc[novCEA_KM['tin'] >= novCEA_KM['tout']]}"
804-
805-
return novCEA_KM
806-
807-
def estimate_hazard_ratio(self):
808-
"""
809-
Estimate the hazard ratio.
810-
"""
818+
assert (preprocessed_data_km["tin"] <= preprocessed_data_km["tout"]).all(), (
819+
f"Left before joining\n"
820+
f"{preprocessed_data_km.loc[preprocessed_data_km['tin'] >= preprocessed_data_km['tout']]}"
821+
)
811822

812823
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
813824
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
814825
cox_ph = CoxPHFitter()
815826
cox_ph.fit(
816-
df=self.df,
827+
df=preprocessed_data_km,
817828
duration_col="tout",
818829
event_col="fault_t_do",
819830
weights_col="weight",

tests/data/temporal_data.csv

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
t,outcome,id,time
2+
0,1,0,0
3+
0,1,0,1
4+
0,1,0,2
5+
0,0,0,3
6+
0,0,0,4
7+
0,1,1,0
8+
0,1,1,1
9+
0,1,1,2
10+
0,0,1,3
11+
0,0,1,4
12+
0,1,2,0
13+
0,1,2,1
14+
0,1,2,2
15+
0,0,2,3
16+
0,0,2,4
17+
0,1,3,0
18+
0,1,3,1
19+
0,1,3,2
20+
0,0,3,3
21+
0,0,3,4
22+
0,1,4,0
23+
0,1,4,1
24+
0,1,4,2
25+
0,0,4,3
26+
0,0,4,4
27+
0,1,5,0
28+
0,1,5,1
29+
0,1,5,2
30+
0,0,5,3
31+
0,0,5,4
32+
0,1,6,0
33+
0,1,6,1
34+
0,0,6,2
35+
0,0,6,3
36+
0,0,6,4
37+
1,1,7,0
38+
1,1,7,1
39+
1,0,7,2
40+
1,0,7,3
41+
1,0,7,4
42+
1,1,8,0
43+
1,1,8,1
44+
1,0,8,2
45+
1,0,8,3
46+
1,0,8,4
47+
1,1,9,0
48+
1,1,9,1
49+
1,0,9,2
50+
1,0,9,3
51+
1,0,9,4
52+
1,1,10,0
53+
1,1,10,1
54+
1,0,10,2
55+
1,0,10,3
56+
1,0,10,4
57+
1,1,11,0
58+
1,1,11,1
59+
1,0,11,2
60+
1,0,11,3
61+
1,0,11,4

0 commit comments

Comments
 (0)