Skip to content

Commit 4d48785

Browse files
committed
IPCW outcome is now an output
1 parent b86c584 commit 4d48785

File tree

4 files changed

+65
-97
lines changed

4 files changed

+65
-97
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from causal_testing.estimation.abstract_estimator import Estimator
1414
from causal_testing.testing.base_test_case import BaseTestCase
15-
from causal_testing.specification.variable import Input, Output
15+
from causal_testing.specification.variable import Variable
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -32,7 +32,7 @@ def __init__(
3232
timesteps_per_observation: int,
3333
control_strategy: list[tuple[int, str, Any]],
3434
treatment_strategy: list[tuple[int, str, Any]],
35-
outcome: str,
35+
outcome: Variable,
3636
status_column: str,
3737
fit_bl_switch_formula: str,
3838
fit_bltd_switch_formula: str,
@@ -58,7 +58,7 @@ def __init__(
5858
treatment) with the most elements multiplied by `timesteps_per_observation`.
5959
"""
6060
super().__init__(
61-
base_test_case=BaseTestCase(Input("_", float), Output(outcome, float)),
61+
base_test_case=BaseTestCase(None, outcome),
6262
treatment_value=[val for _, _, val in treatment_strategy],
6363
control_value=[val for _, _, val in control_strategy],
6464
adjustment_set=None,
@@ -70,7 +70,6 @@ def __init__(
7070
self.timesteps_per_observation = timesteps_per_observation
7171
self.control_strategy = control_strategy
7272
self.treatment_strategy = treatment_strategy
73-
self.outcome = outcome
7473
self.status_column = status_column
7574
self.fit_bl_switch_formula = fit_bl_switch_formula
7675
self.fit_bltd_switch_formula = fit_bltd_switch_formula

causal_testing/testing/causal_test_result.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def to_dict(self, json=False):
8484
"""Return result contents as a dictionary
8585
:return: Dictionary containing contents of causal_test_result
8686
"""
87-
if isinstance(self.estimator.base_test_case.treatment_variable, list):
88-
treatment = [x.name for x in self.estimator.base_test_case.treatment_variable]
89-
else:
90-
treatment = self.estimator.base_test_case.treatment_variable.name
9187
base_dict = {
92-
"treatment": treatment,
88+
"treatment": (
89+
self.estimator.base_test_case.treatment_variable.name
90+
if self.estimator.base_test_case.treatment_variable is not None
91+
else None
92+
),
9393
"control_value": self.estimator.control_value,
9494
"treatment_value": self.estimator.treatment_value,
9595
"outcome": self.estimator.base_test_case.outcome_variable.name,
Lines changed: 56 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import unittest
22
import pandas as pd
3-
import numpy as np
4-
import matplotlib.pyplot as plt
5-
from causal_testing.specification.variable import Input
6-
from causal_testing.utils.validation import CausalValidator
3+
from causal_testing.specification.variable import Input, Output
74

85
from causal_testing.estimation.ipcw_estimator import IPCWEstimator
96

@@ -13,113 +10,85 @@ class TestIPCWEstimator(unittest.TestCase):
1310
Test the IPCW estimator class
1411
"""
1512

13+
def setUp(self) -> None:
14+
self.outcome = Output("outcome", float)
15+
self.status_column = "ok"
16+
self.timesteps_per_intervention = 1
17+
self.control_strategy = [[t, "t", 0] for t in range(1, 4, self.timesteps_per_intervention)]
18+
self.treatment_strategy = [[t, "t", 1] for t in range(1, 4, self.timesteps_per_intervention)]
19+
self.fit_bl_switch_formula = "xo_t_do ~ time"
20+
self.df = pd.read_csv("tests/resources/data/temporal_data.csv")
21+
self.df[self.status_column] = self.df["outcome"] == 1
22+
1623
def test_estimate_hazard_ratio(self):
17-
timesteps_per_intervention = 1
18-
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
19-
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
20-
outcome = "outcome"
21-
fit_bl_switch_formula = "xo_t_do ~ time"
22-
df = pd.read_csv("tests/resources/data/temporal_data.csv")
23-
df["ok"] = df["outcome"] == 1
2424
estimation_model = IPCWEstimator(
25-
df,
26-
timesteps_per_intervention,
27-
control_strategy,
28-
treatment_strategy,
29-
outcome,
30-
"ok",
31-
fit_bl_switch_formula=fit_bl_switch_formula,
32-
fit_bltd_switch_formula=fit_bl_switch_formula,
25+
self.df,
26+
self.timesteps_per_intervention,
27+
self.control_strategy,
28+
self.treatment_strategy,
29+
self.outcome,
30+
self.status_column,
31+
fit_bl_switch_formula=self.fit_bl_switch_formula,
32+
fit_bltd_switch_formula=self.fit_bl_switch_formula,
3333
eligibility=None,
3434
)
35-
estimate, intervals = estimation_model.estimate_hazard_ratio()
35+
estimate, _ = estimation_model.estimate_hazard_ratio()
3636
self.assertEqual(round(estimate["trtrand"], 3), 1.351)
3737

3838
def test_invalid_treatment_strategies(self):
39-
timesteps_per_intervention = 1
40-
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
41-
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
42-
outcome = "outcome"
43-
fit_bl_switch_formula = "xo_t_do ~ time"
44-
df = pd.read_csv("tests/resources/data/temporal_data.csv")
45-
df["t"] = (["1", "0"] * len(df))[: len(df)]
46-
df["ok"] = df["outcome"] == 1
4739
with self.assertRaises(ValueError):
48-
estimation_model = IPCWEstimator(
49-
df,
50-
timesteps_per_intervention,
51-
control_strategy,
52-
treatment_strategy,
53-
outcome,
54-
"ok",
55-
fit_bl_switch_formula=fit_bl_switch_formula,
56-
fit_bltd_switch_formula=fit_bl_switch_formula,
40+
IPCWEstimator(
41+
self.df.assign(t=(["1", "0"] * len(self.df))[: len(self.df)]),
42+
self.timesteps_per_intervention,
43+
self.control_strategy,
44+
self.treatment_strategy,
45+
self.outcome,
46+
self.status_column,
47+
fit_bl_switch_formula=self.fit_bl_switch_formula,
48+
fit_bltd_switch_formula=self.fit_bl_switch_formula,
5749
eligibility=None,
5850
)
5951

6052
def test_invalid_fault_t_do(self):
61-
timesteps_per_intervention = 1
62-
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
63-
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
64-
outcome = "outcome"
65-
fit_bl_switch_formula = "xo_t_do ~ time"
66-
df = pd.read_csv("tests/resources/data/temporal_data.csv")
67-
df["ok"] = df["outcome"] == 1
6853
estimation_model = IPCWEstimator(
69-
df,
70-
timesteps_per_intervention,
71-
control_strategy,
72-
treatment_strategy,
73-
outcome,
74-
"ok",
75-
fit_bl_switch_formula=fit_bl_switch_formula,
76-
fit_bltd_switch_formula=fit_bl_switch_formula,
54+
self.df.assign(outcome=1),
55+
self.timesteps_per_intervention,
56+
self.control_strategy,
57+
self.treatment_strategy,
58+
self.outcome,
59+
self.status_column,
60+
fit_bl_switch_formula=self.fit_bl_switch_formula,
61+
fit_bltd_switch_formula=self.fit_bl_switch_formula,
7762
eligibility=None,
7863
)
7964
estimation_model.df["fault_t_do"] = 0
8065
with self.assertRaises(ValueError):
81-
estimate, intervals = estimation_model.estimate_hazard_ratio()
66+
estimation_model.estimate_hazard_ratio()
8267

8368
def test_no_individual_began_control_strategy(self):
84-
timesteps_per_intervention = 1
85-
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
86-
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
87-
outcome = "outcome"
88-
fit_bl_switch_formula = "xo_t_do ~ time"
89-
df = pd.read_csv("tests/resources/data/temporal_data.csv")
90-
df["t"] = 1
91-
df["ok"] = df["outcome"] == 1
9269
with self.assertRaises(ValueError):
93-
estimation_model = IPCWEstimator(
94-
df,
95-
timesteps_per_intervention,
96-
control_strategy,
97-
treatment_strategy,
98-
outcome,
99-
"ok",
100-
fit_bl_switch_formula=fit_bl_switch_formula,
101-
fit_bltd_switch_formula=fit_bl_switch_formula,
70+
IPCWEstimator(
71+
self.df.assign(t=1),
72+
self.timesteps_per_intervention,
73+
self.control_strategy,
74+
self.treatment_strategy,
75+
self.outcome,
76+
self.status_column,
77+
fit_bl_switch_formula=self.fit_bl_switch_formula,
78+
fit_bltd_switch_formula=self.fit_bl_switch_formula,
10279
eligibility=None,
10380
)
10481

10582
def test_no_individual_began_treatment_strategy(self):
106-
timesteps_per_intervention = 1
107-
control_strategy = [[t, "t", 0] for t in range(1, 4, timesteps_per_intervention)]
108-
treatment_strategy = [[t, "t", 1] for t in range(1, 4, timesteps_per_intervention)]
109-
outcome = "outcome"
110-
fit_bl_switch_formula = "xo_t_do ~ time"
111-
df = pd.read_csv("tests/resources/data/temporal_data.csv")
112-
df["t"] = 0
113-
df["ok"] = df["outcome"] == 1
11483
with self.assertRaises(ValueError):
115-
estimation_model = IPCWEstimator(
116-
df,
117-
timesteps_per_intervention,
118-
control_strategy,
119-
treatment_strategy,
120-
outcome,
121-
"ok",
122-
fit_bl_switch_formula=fit_bl_switch_formula,
123-
fit_bltd_switch_formula=fit_bl_switch_formula,
84+
IPCWEstimator(
85+
self.df.assign(t=0),
86+
self.timesteps_per_intervention,
87+
self.control_strategy,
88+
self.treatment_strategy,
89+
self.outcome,
90+
self.status_column,
91+
fit_bl_switch_formula=self.fit_bl_switch_formula,
92+
fit_bltd_switch_formula=self.fit_bl_switch_formula,
12493
eligibility=None,
12594
)

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_data_adequacy_group_by(self):
9292
fit_bltd_switch_formula=fit_bl_switch_formula,
9393
eligibility=None,
9494
)
95-
base_test_case = estimation_model.base_test_case
95+
base_test_case = BaseTestCase(Input("t", float), Output("outcome", float))
9696

9797
causal_test_case = CausalTestCase(
9898
base_test_case=base_test_case,

0 commit comments

Comments
 (0)