Skip to content

Commit 89eeaed

Browse files
committed
Support for treatment strategies with interventions at non-consecutive timesteps
1 parent d5f0dad commit 89eeaed

File tree

9 files changed

+53
-158
lines changed

9 files changed

+53
-158
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22

33
import logging
44
from math import ceil
5+
from typing import Any
56

67
import numpy as np
78
import pandas as pd
89
import statsmodels.formula.api as smf
910
from lifelines import CoxPHFitter
1011

11-
from causal_testing.specification.capabilities import TreatmentSequence, Capability
1212
from causal_testing.estimation.abstract_estimator import Estimator
1313

1414
logger = logging.getLogger(__name__)
1515

1616

1717
class IPCWEstimator(Estimator):
1818
"""
19-
Class to perform inverse probability of censoring weighting (IPCW) estimation
19+
Class to perform Inverse Probability of Censoring Weighting (IPCW) estimation
2020
for sequences of treatments over time-varying data.
2121
"""
2222

@@ -25,37 +25,57 @@ class IPCWEstimator(Estimator):
2525
def __init__(
    self,
    df: pd.DataFrame,
    timesteps_per_observation: int,
    control_strategy: list[tuple[int, str, Any]],
    treatment_strategy: list[tuple[int, str, Any]],
    outcome: str,
    status_column: str,
    fit_bl_switch_formula: str,
    fit_bltd_switch_formula: str,
    eligibility=None,
    alpha: float = 0.05,
    total_time: float = None,
):
    """
    Initialise IPCWEstimator.

    :param: df: Input DataFrame containing time-varying data.
    :param: timesteps_per_observation: Number of timesteps per observation.
    :param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
    :param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
    :param: outcome: Name of the outcome column in the DataFrame.
    :param: status_column: Name of the status column in the DataFrame, which should be True for operating
                           normally, False for a fault.
    :param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
    :param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
    :param: eligibility: Expression evaluated against the data with `DataFrame.eval` to determine eligibility
                         for treatment. Defaults to None for "always eligible".
    :param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
    :param: total_time: Total time for the analysis. Defaults to one plus the length of the strategy (control or
                        treatment) with the most elements multiplied by `timesteps_per_observation`.
    """
    super().__init__(
        [var for _, var, _ in treatment_strategy],
        [val for _, _, val in treatment_strategy],
        [val for _, _, val in control_strategy],
        None,
        outcome,
        df,
        None,
        alpha=alpha,
        query="",
    )
    self.timesteps_per_observation = timesteps_per_observation
    self.control_strategy = control_strategy
    self.treatment_strategy = treatment_strategy
    self.outcome = outcome
    self.status_column = status_column
    self.fit_bl_switch_formula = fit_bl_switch_formula
    self.fit_bltd_switch_formula = fit_bltd_switch_formula
    self.eligibility = eligibility
    self.df = df
    if total_time is None:
        # NOTE(review): this default is derived from the *number* of strategy entries, not from the
        # largest timestep they mention; for strategies with non-consecutive timesteps callers may
        # want to pass `total_time` explicitly - confirm the intended default.
        total_time = (
            max(len(self.control_strategy), len(self.treatment_strategy)) + 1
        ) * self.timesteps_per_observation
    # Always set the attribute: previously it was only assigned when `total_time` was None, so an
    # explicitly supplied value left `self.total_time` undefined and `preprocess_data` failed.
    self.total_time = total_time
    self.preprocess_data()
6080

6181
def add_modelling_assumptions(self):
@@ -92,16 +112,16 @@ def setup_fault_t_do(self, individual: pd.DataFrame):
92112
index is the time point at which the event of interest (i.e. a fault)
93113
occurred.
94114
"""
95-
fault = individual[~individual[self.fault_column]]
115+
fault = individual[~individual[self.status_column]]
96116
fault_t_do = pd.Series(np.zeros(len(individual)), index=individual.index)
97117

98118
if not fault.empty:
99119
fault_time = individual["time"].loc[fault.index[0]]
100120
# Ceiling to nearest observation point
101-
fault_time = ceil(fault_time / self.timesteps_per_intervention) * self.timesteps_per_intervention
121+
fault_time = ceil(fault_time / self.timesteps_per_observation) * self.timesteps_per_observation
102122
# Set the correct observation point to be the fault time of doing (fault_t_do)
103123
observations = individual.loc[
104-
(individual["time"] % self.timesteps_per_intervention == 0) & (individual["time"] < fault_time)
124+
(individual["time"] % self.timesteps_per_observation == 0) & (individual["time"] < fault_time)
105125
]
106126
if not observations.empty:
107127
fault_t_do.loc[observations.index[0]] = 1
@@ -113,11 +133,11 @@ def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.00
113133
"""
114134
Return the time at which the event of interest (i.e. a fault) occurred.
115135
"""
116-
fault = individual[~individual[self.fault_column]]
136+
fault = individual[~individual[self.status_column]]
117137
fault_time = (
118138
individual["time"].loc[fault.index[0]]
119139
if not fault.empty
120-
else (individual["time"].max() + self.timesteps_per_intervention)
140+
else (individual["time"].max() + self.timesteps_per_observation)
121141
)
122142
return pd.DataFrame({"fault_time": np.repeat(fault_time + perturbation, len(individual))})
123143

@@ -130,15 +150,14 @@ def preprocess_data(self):
130150
self.df["eligible"] = self.df.eval(self.eligibility) if self.eligibility is not None else True
131151

132152
# when did a fault occur?
133-
self.df["fault_time"] = self.df.groupby("id")[[self.fault_column, "time"]].apply(self.setup_fault_time).values
153+
self.df["fault_time"] = self.df.groupby("id")[[self.status_column, "time"]].apply(self.setup_fault_time).values
134154
self.df["fault_t_do"] = (
135-
self.df.groupby("id")[["id", "time", self.fault_column]].apply(self.setup_fault_t_do).values
155+
self.df.groupby("id")[["id", "time", self.status_column]].apply(self.setup_fault_t_do).values
136156
)
137157
assert not pd.isnull(self.df["fault_time"]).any()
138158

139159
living_runs = self.df.query("fault_time > 0").loc[
140-
(self.df["time"] % self.timesteps_per_intervention == 0)
141-
& (self.df["time"] <= self.control_strategy.total_time())
160+
(self.df["time"] % self.timesteps_per_observation == 0) & (self.df["time"] <= self.total_time)
142161
]
143162

144163
individuals = []
@@ -152,25 +171,20 @@ def preprocess_data(self):
152171
)
153172

154173
strategy_followed = [
155-
Capability(
156-
c.variable,
157-
individual.loc[individual["time"] == c.start_time, c.variable].values[0],
158-
c.start_time,
159-
c.end_time,
160-
)
161-
for c in self.treatment_strategy.capabilities
174+
(t, var, individual.loc[individual["time"] == t, var].values[0])
175+
for t, var, val in self.treatment_strategy
162176
]
163177

164178
# Control flow:
165179
# Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
166180
# Individuals that don't start off in either arm are left out
167181
for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
168-
if strategy_assigned.capabilities[0] == strategy_followed[0] and individual.eligible.iloc[0]:
182+
if strategy_assigned[0] == strategy_followed[0] and individual.eligible.iloc[0]:
169183
individual["id"] = new_id
170184
new_id += 1
171185
individual["trtrand"] = inx
172186
individual["xo_t_do"] = self.setup_xo_t_do(
173-
strategy_assigned.capabilities, strategy_followed, individual["eligible"]
187+
strategy_assigned, strategy_followed, individual["eligible"]
174188
)
175189
individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
176190
if len(individuals) == 0:
@@ -222,7 +236,7 @@ def estimate_hazard_ratio(self):
222236

223237
preprocessed_data["tin"] = preprocessed_data["time"]
224238
preprocessed_data["tout"] = pd.concat(
225-
[(preprocessed_data["time"] + self.timesteps_per_intervention), preprocessed_data["fault_time"]],
239+
[(preprocessed_data["time"] + self.timesteps_per_observation), preprocessed_data["fault_time"]],
226240
axis=1,
227241
).min(axis=1)
228242

causal_testing/specification/capabilities.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

tests/estimation_tests/test_cubic_spline_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
109
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator

tests/estimation_tests/test_instrumental_variable_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.instrumental_variable_estimator import InstrumentalVariableEstimator
109

tests/estimation_tests/test_ipcw_estimator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.ipcw_estimator import IPCWEstimator
109

@@ -16,8 +15,8 @@ class TestIPCWEstimator(unittest.TestCase):
1615

1716
def test_estimate_hazard_ratio(self):
1817
timesteps_per_intervention = 1
19-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
20-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
18+
control_strategy = [(t, "t", 0) for t in range(1, 4, timesteps_per_intervention)]
19+
treatment_strategy = [(t, "t", 1) for t in range(1, 4, timesteps_per_intervention)]
2120
outcome = "outcome"
2221
fit_bl_switch_formula = "xo_t_do ~ time"
2322
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -38,8 +37,8 @@ def test_estimate_hazard_ratio(self):
3837

3938
def test_invalid_treatment_strategies(self):
4039
timesteps_per_intervention = 1
41-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
42-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
40+
control_strategy = [(t, "t", 0) for t in range(1, 4, timesteps_per_intervention)]
41+
treatment_strategy = [(t, "t", 1) for t in range(1, 4, timesteps_per_intervention)]
4342
outcome = "outcome"
4443
fit_bl_switch_formula = "xo_t_do ~ time"
4544
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -60,8 +59,8 @@ def test_invalid_treatment_strategies(self):
6059

6160
def test_invalid_fault_t_do(self):
6261
timesteps_per_intervention = 1
63-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
64-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
62+
control_strategy = [(t, "t", 0) for t in range(1, 4, timesteps_per_intervention)]
63+
treatment_strategy = [(t, "t", 1) for t in range(1, 4, timesteps_per_intervention)]
6564
outcome = "outcome"
6665
fit_bl_switch_formula = "xo_t_do ~ time"
6766
df = pd.read_csv("tests/resources/data/temporal_data.csv")

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
109
from causal_testing.estimation.genetic_programming_regression_fitter import reciprocal

tests/estimation_tests/test_logistic_regression_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
98

109

tests/specification_tests/test_capabilities.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from causal_testing.specification.variable import Input, Output, Meta
1717
from causal_testing.specification.scenario import Scenario
1818
from causal_testing.specification.causal_specification import CausalSpecification
19-
from causal_testing.specification.capabilities import TreatmentSequence
2019
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2120

2221

@@ -112,8 +111,8 @@ def test_data_adequacy_cateogorical(self):
112111

113112
def test_data_adequacy_group_by(self):
114113
timesteps_per_intervention = 1
115-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
116-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
114+
control_strategy = [(t, "t", 0) for t in range(1, 4, timesteps_per_intervention)]
115+
treatment_strategy = [(t, "t", 1) for t in range(1, 4, timesteps_per_intervention)]
117116
outcome = "outcome"
118117
fit_bl_switch_formula = "xo_t_do ~ time"
119118
df = pd.read_csv("tests/resources/data/temporal_data.csv")

0 commit comments

Comments
 (0)