Skip to content

Commit de0a676

Browse files
authored
Merge pull request #292 from CITCOM-project/sparse-treatment-strategies
Support for treatment strategies with interventions at non-consecutiv…
2 parents d5f0dad + e6d46cd commit de0a676

File tree

10 files changed

+293
-245
lines changed

10 files changed

+293
-245
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 226 additions & 92 deletions
Large diffs are not rendered by default.

causal_testing/specification/capabilities.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

causal_testing/testing/causal_test_case.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
from typing import Any
5-
import numpy as np
65

76
from causal_testing.specification.variable import Variable
87
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
@@ -81,21 +80,13 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
8180
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
8281
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
8382
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
84-
try:
85-
effect, confidence_intervals = estimate_effect(**self.estimate_params)
86-
return CausalTestResult(
87-
estimator=estimator,
88-
test_value=TestValue(self.estimate_type, effect),
89-
effect_modifier_configuration=self.effect_modifier_configuration,
90-
confidence_intervals=confidence_intervals,
91-
)
92-
except np.linalg.LinAlgError:
93-
return CausalTestResult(
94-
estimator=estimator,
95-
test_value=TestValue(self.estimate_type, None),
96-
effect_modifier_configuration=self.effect_modifier_configuration,
97-
confidence_intervals=None,
98-
)
83+
effect, confidence_intervals = estimate_effect(**self.estimate_params)
84+
return CausalTestResult(
85+
estimator=estimator,
86+
test_value=TestValue(self.estimate_type, effect),
87+
effect_modifier_configuration=self.effect_modifier_configuration,
88+
confidence_intervals=confidence_intervals,
89+
)
9990

10091
def __str__(self):
10192
treatment_config = {self.treatment_variable.name: self.treatment_value}

tests/estimation_tests/test_cubic_spline_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
109
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator

tests/estimation_tests/test_instrumental_variable_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.instrumental_variable_estimator import InstrumentalVariableEstimator
109

tests/estimation_tests/test_ipcw_estimator.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.ipcw_estimator import IPCWEstimator
109

@@ -16,8 +15,8 @@ class TestIPCWEstimator(unittest.TestCase):
1615

1716
def test_estimate_hazard_ratio(self):
1817
timesteps_per_intervention = 1
19-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
20-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
18+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
19+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
2120
outcome = "outcome"
2221
fit_bl_switch_formula = "xo_t_do ~ time"
2322
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -34,12 +33,12 @@ def test_estimate_hazard_ratio(self):
3433
eligibility=None,
3534
)
3635
estimate, intervals = estimation_model.estimate_hazard_ratio()
37-
self.assertEqual(estimate["trtrand"], 1.0)
36+
self.assertEqual(round(estimate["trtrand"], 3), 1.351)
3837

3938
def test_invalid_treatment_strategies(self):
4039
timesteps_per_intervention = 1
41-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
42-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
40+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
41+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
4342
outcome = "outcome"
4443
fit_bl_switch_formula = "xo_t_do ~ time"
4544
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -60,8 +59,8 @@ def test_invalid_treatment_strategies(self):
6059

6160
def test_invalid_fault_t_do(self):
6261
timesteps_per_intervention = 1
63-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
64-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
62+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
63+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
6564
outcome = "outcome"
6665
fit_bl_switch_formula = "xo_t_do ~ time"
6766
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -80,3 +79,47 @@ def test_invalid_fault_t_do(self):
8079
estimation_model.df["fault_t_do"] = 0
8180
with self.assertRaises(ValueError):
8281
estimate, intervals = estimation_model.estimate_hazard_ratio()
82+
83+
def test_no_individual_began_control_strategy(self):
84+
timesteps_per_intervention = 1
85+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
86+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
87+
outcome = "outcome"
88+
fit_bl_switch_formula = "xo_t_do ~ time"
89+
df = pd.read_csv("tests/resources/data/temporal_data.csv")
90+
df["t"] = 1
91+
df["ok"] = df["outcome"] == 1
92+
with self.assertRaises(ValueError):
93+
estimation_model = IPCWEstimator(
94+
df,
95+
timesteps_per_intervention,
96+
control_strategy,
97+
treatment_strategy,
98+
outcome,
99+
"ok",
100+
fit_bl_switch_formula=fit_bl_switch_formula,
101+
fit_bltd_switch_formula=fit_bl_switch_formula,
102+
eligibility=None,
103+
)
104+
105+
def test_no_individual_began_treatment_strategy(self):
106+
timesteps_per_intervention = 1
107+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
108+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
109+
outcome = "outcome"
110+
fit_bl_switch_formula = "xo_t_do ~ time"
111+
df = pd.read_csv("tests/resources/data/temporal_data.csv")
112+
df["t"] = 0
113+
df["ok"] = df["outcome"] == 1
114+
with self.assertRaises(ValueError):
115+
estimation_model = IPCWEstimator(
116+
df,
117+
timesteps_per_intervention,
118+
control_strategy,
119+
treatment_strategy,
120+
outcome,
121+
"ok",
122+
fit_bl_switch_formula=fit_bl_switch_formula,
123+
fit_bltd_switch_formula=fit_bl_switch_formula,
124+
eligibility=None,
125+
)

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87

98
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
109
from causal_testing.estimation.genetic_programming_regression_fitter import reciprocal

tests/estimation_tests/test_logistic_regression_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import matplotlib.pyplot as plt
55
from causal_testing.specification.variable import Input
66
from causal_testing.utils.validation import CausalValidator
7-
from causal_testing.specification.capabilities import TreatmentSequence
87
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
98

109

tests/specification_tests/test_capabilities.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
import os
12
import unittest
23
from pathlib import Path
3-
from statistics import StatisticsError
44
import scipy
5-
import os
65
import pandas as pd
76

87
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -11,12 +10,9 @@
1110
from causal_testing.testing.causal_test_case import CausalTestCase
1211
from causal_testing.testing.causal_test_suite import CausalTestSuite
1312
from causal_testing.testing.causal_test_adequacy import DAGAdequacy
14-
from causal_testing.testing.causal_test_outcome import NoEffect, Positive, SomeEffect
13+
from causal_testing.testing.causal_test_outcome import NoEffect, SomeEffect
1514
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
16-
from causal_testing.specification.variable import Input, Output, Meta
1715
from causal_testing.specification.scenario import Scenario
18-
from causal_testing.specification.causal_specification import CausalSpecification
19-
from causal_testing.specification.capabilities import TreatmentSequence
2016
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2117

2218

@@ -112,8 +108,8 @@ def test_data_adequacy_cateogorical(self):
112108

113109
def test_data_adequacy_group_by(self):
114110
timesteps_per_intervention = 1
115-
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
116-
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
111+
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
112+
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
117113
outcome = "outcome"
118114
fit_bl_switch_formula = "xo_t_do ~ time"
119115
df = pd.read_csv("tests/resources/data/temporal_data.csv")
@@ -145,11 +141,12 @@ def test_data_adequacy_group_by(self):
145141
causal_test_result = causal_test_case.execute_test(estimation_model, None)
146142
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, group_by="id")
147143
adequacy_metric.measure_adequacy()
148-
causal_test_result.adequacy = adequacy_metric
149-
print(causal_test_result.adequacy.to_dict())
144+
adequacy_dict = adequacy_metric.to_dict()
145+
self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.857)
146+
adequacy_dict.pop("kurtosis")
150147
self.assertEqual(
151-
causal_test_result.adequacy.to_dict(),
152-
{"kurtosis": {"trtrand": 0.0}, "bootstrap_size": 100, "passing": 0, "successful": 95},
148+
adequacy_dict,
149+
{"bootstrap_size": 100, "passing": 32, "successful": 100},
153150
)
154151

155152
def test_dag_adequacy_dependent(self):

0 commit comments

Comments
 (0)