Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 46 additions & 31 deletions causal_testing/estimation/ipcw_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@

import logging
from math import ceil
from typing import Any

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from lifelines import CoxPHFitter

from causal_testing.specification.capabilities import TreatmentSequence, Capability
from causal_testing.estimation.abstract_estimator import Estimator

logger = logging.getLogger(__name__)


class IPCWEstimator(Estimator):
"""
Class to perform inverse probability of censoring weighting (IPCW) estimation
Class to perform Inverse Probability of Censoring Weighting (IPCW) estimation
for sequences of treatments over time-varying data.
"""

Expand All @@ -25,37 +25,58 @@ class IPCWEstimator(Estimator):
def __init__(
self,
df: pd.DataFrame,
timesteps_per_intervention: int,
control_strategy: TreatmentSequence,
treatment_strategy: TreatmentSequence,
timesteps_per_observation: int,
control_strategy: list[tuple[int, str, Any]],
treatment_strategy: list[tuple[int, str, Any]],
outcome: str,
fault_column: str,
status_column: str,
fit_bl_switch_formula: str,
fit_bltd_switch_formula: str,
eligibility=None,
alpha: float = 0.05,
total_time: float = None,
):
"""
Initialise IPCWEstimator.

:param: df: Input DataFrame containing time-varying data.
:param: timesteps_per_observation: Number of timesteps per observation.
:param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
:param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
:param: outcome: Name of the outcome column in the DataFrame.
:param: status_column: Name of the status column in the DataFrame, which should be True for operating normally,
False for a fault.
:param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
:param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
:param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
:param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
:param: total_time: Total time for the analysis. Defaults to one plus the length of the strategy (control or
treatment) with the most elements multiplied by `timesteps_per_observation`.
"""
super().__init__(
[c.variable for c in treatment_strategy.capabilities],
[c.value for c in treatment_strategy.capabilities],
[c.value for c in control_strategy.capabilities],
[var for _, var, _ in treatment_strategy],
[val for _, _, val in treatment_strategy],
[val for _, _, val in control_strategy],
None,
outcome,
df,
None,
alpha=alpha,
query="",
)
self.timesteps_per_intervention = timesteps_per_intervention
self.timesteps_per_observation = timesteps_per_observation
self.control_strategy = control_strategy
self.treatment_strategy = treatment_strategy
self.outcome = outcome
self.fault_column = fault_column
self.timesteps_per_intervention = timesteps_per_intervention
self.status_column = status_column
self.fit_bl_switch_formula = fit_bl_switch_formula
self.fit_bltd_switch_formula = fit_bltd_switch_formula
self.eligibility = eligibility
self.df = df
if total_time is None:
self.total_time = (
max(len(self.control_strategy), len(self.treatment_strategy)) + 1
) * self.timesteps_per_observation
self.preprocess_data()

def add_modelling_assumptions(self):
Expand Down Expand Up @@ -92,16 +113,16 @@ def setup_fault_t_do(self, individual: pd.DataFrame):
index is the time point at which the event of interest (i.e. a fault)
occurred.
"""
fault = individual[~individual[self.fault_column]]
fault = individual[~individual[self.status_column]]
fault_t_do = pd.Series(np.zeros(len(individual)), index=individual.index)

if not fault.empty:
fault_time = individual["time"].loc[fault.index[0]]
# Ceiling to nearest observation point
fault_time = ceil(fault_time / self.timesteps_per_intervention) * self.timesteps_per_intervention
fault_time = ceil(fault_time / self.timesteps_per_observation) * self.timesteps_per_observation
# Set the correct observation point to be the fault time of doing (fault_t_do)
observations = individual.loc[
(individual["time"] % self.timesteps_per_intervention == 0) & (individual["time"] < fault_time)
(individual["time"] % self.timesteps_per_observation == 0) & (individual["time"] < fault_time)
]
if not observations.empty:
fault_t_do.loc[observations.index[0]] = 1
Expand All @@ -113,11 +134,11 @@ def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.00
"""
Return the time at which the event of interest (i.e. a fault) occurred.
"""
fault = individual[~individual[self.fault_column]]
fault = individual[~individual[self.status_column]]
fault_time = (
individual["time"].loc[fault.index[0]]
if not fault.empty
else (individual["time"].max() + self.timesteps_per_intervention)
else (individual["time"].max() + self.timesteps_per_observation)
)
return pd.DataFrame({"fault_time": np.repeat(fault_time + perturbation, len(individual))})

Expand All @@ -130,15 +151,14 @@ def preprocess_data(self):
self.df["eligible"] = self.df.eval(self.eligibility) if self.eligibility is not None else True

# when did a fault occur?
self.df["fault_time"] = self.df.groupby("id")[[self.fault_column, "time"]].apply(self.setup_fault_time).values
self.df["fault_time"] = self.df.groupby("id")[[self.status_column, "time"]].apply(self.setup_fault_time).values
self.df["fault_t_do"] = (
self.df.groupby("id")[["id", "time", self.fault_column]].apply(self.setup_fault_t_do).values
self.df.groupby("id")[["id", "time", self.status_column]].apply(self.setup_fault_t_do).values
)
assert not pd.isnull(self.df["fault_time"]).any()

living_runs = self.df.query("fault_time > 0").loc[
(self.df["time"] % self.timesteps_per_intervention == 0)
& (self.df["time"] <= self.control_strategy.total_time())
(self.df["time"] % self.timesteps_per_observation == 0) & (self.df["time"] <= self.total_time)
]

individuals = []
Expand All @@ -152,25 +172,20 @@ def preprocess_data(self):
)

strategy_followed = [
Capability(
c.variable,
individual.loc[individual["time"] == c.start_time, c.variable].values[0],
c.start_time,
c.end_time,
)
for c in self.treatment_strategy.capabilities
[t, var, individual.loc[individual["time"] == t, var].values[0]]
for t, var, val in self.treatment_strategy
]

# Control flow:
# Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
# Individuals that don't start off in either arm are left out
for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
if strategy_assigned.capabilities[0] == strategy_followed[0] and individual.eligible.iloc[0]:
if strategy_assigned[0] == strategy_followed[0] and individual.eligible.iloc[0]:
individual["id"] = new_id
new_id += 1
individual["trtrand"] = inx
individual["xo_t_do"] = self.setup_xo_t_do(
strategy_assigned.capabilities, strategy_followed, individual["eligible"]
strategy_assigned, strategy_followed, individual["eligible"]
)
individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
if len(individuals) == 0:
Expand Down Expand Up @@ -222,7 +237,7 @@ def estimate_hazard_ratio(self):

preprocessed_data["tin"] = preprocessed_data["time"]
preprocessed_data["tout"] = pd.concat(
[(preprocessed_data["time"] + self.timesteps_per_intervention), preprocessed_data["fault_time"]],
[(preprocessed_data["time"] + self.timesteps_per_observation), preprocessed_data["fault_time"]],
axis=1,
).min(axis=1)

Expand Down
77 changes: 0 additions & 77 deletions causal_testing/specification/capabilities.py
Copy link
Contributor

@f-allian f-allian Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@f-allian, is the codecov just failing on this because the codebase has got slightly bigger? It says I have 100% coverage of the patch, so that's all I can think of, but adding in extra code with 100% coverage should be fine, right?

@jmafoster1 Looks like the coverage is failing because you've deleted causal_testing/specification/capabilities.py. Do you know if this was intentional? If so, recommitting it should fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was intentional. It is no longer required as the new functionality requires timesteps to be specified explicitly, which is what I wrote capabilities.py to handle implicitly.

This file was deleted.

1 change: 0 additions & 1 deletion tests/estimation_tests/test_cubic_spline_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.pyplot as plt
from causal_testing.specification.variable import Input
from causal_testing.utils.validation import CausalValidator
from causal_testing.specification.capabilities import TreatmentSequence

from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.pyplot as plt
from causal_testing.specification.variable import Input
from causal_testing.utils.validation import CausalValidator
from causal_testing.specification.capabilities import TreatmentSequence

from causal_testing.estimation.instrumental_variable_estimator import InstrumentalVariableEstimator

Expand Down
13 changes: 6 additions & 7 deletions tests/estimation_tests/test_ipcw_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.pyplot as plt
from causal_testing.specification.variable import Input
from causal_testing.utils.validation import CausalValidator
from causal_testing.specification.capabilities import TreatmentSequence

from causal_testing.estimation.ipcw_estimator import IPCWEstimator

Expand All @@ -16,8 +15,8 @@ class TestIPCWEstimator(unittest.TestCase):

def test_estimate_hazard_ratio(self):
timesteps_per_intervention = 1
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
outcome = "outcome"
fit_bl_switch_formula = "xo_t_do ~ time"
df = pd.read_csv("tests/resources/data/temporal_data.csv")
Expand All @@ -38,8 +37,8 @@ def test_estimate_hazard_ratio(self):

def test_invalid_treatment_strategies(self):
timesteps_per_intervention = 1
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
outcome = "outcome"
fit_bl_switch_formula = "xo_t_do ~ time"
df = pd.read_csv("tests/resources/data/temporal_data.csv")
Expand All @@ -60,8 +59,8 @@ def test_invalid_treatment_strategies(self):

def test_invalid_fault_t_do(self):
timesteps_per_intervention = 1
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
outcome = "outcome"
fit_bl_switch_formula = "xo_t_do ~ time"
df = pd.read_csv("tests/resources/data/temporal_data.csv")
Expand Down
1 change: 0 additions & 1 deletion tests/estimation_tests/test_linear_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.pyplot as plt
from causal_testing.specification.variable import Input
from causal_testing.utils.validation import CausalValidator
from causal_testing.specification.capabilities import TreatmentSequence

from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
from causal_testing.estimation.genetic_programming_regression_fitter import reciprocal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.pyplot as plt
from causal_testing.specification.variable import Input
from causal_testing.utils.validation import CausalValidator
from causal_testing.specification.capabilities import TreatmentSequence
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator


Expand Down
36 changes: 0 additions & 36 deletions tests/specification_tests/test_capabilities.py

This file was deleted.

5 changes: 2 additions & 3 deletions tests/testing_tests/test_causal_test_adequacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from causal_testing.specification.variable import Input, Output, Meta
from causal_testing.specification.scenario import Scenario
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.specification.capabilities import TreatmentSequence
from causal_testing.testing.causal_test_adequacy import DataAdequacy


Expand Down Expand Up @@ -112,8 +111,8 @@ def test_data_adequacy_cateogorical(self):

def test_data_adequacy_group_by(self):
timesteps_per_intervention = 1
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
outcome = "outcome"
fit_bl_switch_formula = "xo_t_do ~ time"
df = pd.read_csv("tests/resources/data/temporal_data.csv")
Expand Down
Loading