Skip to content

Commit b86c584

Browse files
committed
All the tests pass and got rid of JSON front
1 parent 77ac1b8 commit b86c584

26 files changed

+289
-944
lines changed

causal_testing/estimation/abstract_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66

77
import pandas as pd
88

9+
from causal_testing.testing.base_test_case import BaseTestCase
10+
911
logger = logging.getLogger(__name__)
1012

1113

@@ -30,21 +32,19 @@ class Estimator(ABC):
3032
def __init__(
3133
# pylint: disable=too-many-arguments
3234
self,
33-
treatment: str,
35+
base_test_case: BaseTestCase,
3436
treatment_value: float,
3537
control_value: float,
3638
adjustment_set: set,
37-
outcome: str,
3839
df: pd.DataFrame = None,
3940
effect_modifiers: dict[str:Any] = None,
4041
alpha: float = 0.05,
4142
query: str = "",
4243
):
43-
self.treatment = treatment
44+
self.base_test_case = base_test_case
4445
self.treatment_value = treatment_value
4546
self.control_value = control_value
4647
self.adjustment_set = adjustment_set
47-
self.outcome = outcome
4848
self.alpha = alpha
4949
self.df = df.query(query) if query else df
5050

causal_testing/estimation/abstract_regression_estimator.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from causal_testing.specification.variable import Variable
1212
from causal_testing.estimation.abstract_estimator import Estimator
13+
from causal_testing.testing.base_test_case import BaseTestCase
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -22,23 +23,21 @@ class RegressionEstimator(Estimator):
2223
def __init__(
2324
# pylint: disable=too-many-arguments
2425
self,
25-
treatment: str,
26+
base_test_case: BaseTestCase,
2627
treatment_value: float,
2728
control_value: float,
2829
adjustment_set: set,
29-
outcome: str,
3030
df: pd.DataFrame = None,
3131
effect_modifiers: dict[Variable:Any] = None,
3232
formula: str = None,
3333
alpha: float = 0.05,
3434
query: str = "",
3535
):
3636
super().__init__(
37-
treatment=treatment,
37+
base_test_case=base_test_case,
3838
treatment_value=treatment_value,
3939
control_value=control_value,
4040
adjustment_set=adjustment_set,
41-
outcome=outcome,
4241
df=df,
4342
effect_modifiers=effect_modifiers,
4443
alpha=alpha,
@@ -53,8 +52,10 @@ def __init__(
5352
if formula is not None:
5453
self.formula = formula
5554
else:
56-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
57-
self.formula = f"{outcome} ~ {'+'.join(terms)}"
55+
terms = (
56+
[base_test_case.treatment_variable.name] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
57+
)
58+
self.formula = f"{base_test_case.outcome_variable.name} ~ {'+'.join(terms)}"
5859

5960
@property
6061
@abstractmethod
@@ -104,7 +105,7 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
104105

105106
x = pd.DataFrame(columns=self.df.columns)
106107
x["Intercept"] = 1 # self.intercept
107-
x[self.treatment] = [self.treatment_value, self.control_value]
108+
x[self.base_test_case.treatment_variable.name] = [self.treatment_value, self.control_value]
108109

109110
for k, v in adjustment_config.items():
110111
x[k] = v
@@ -116,5 +117,5 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
116117
x = pd.get_dummies(x, columns=[col], drop_first=True)
117118

118119
# This has to be here in case the treatment variable is in an I(...) block in the self.formula
119-
x[self.treatment] = [self.treatment_value, self.control_value]
120+
x[self.base_test_case.treatment_variable.name] = [self.treatment_value, self.control_value]
120121
return model.get_prediction(x).summary_frame()

causal_testing/estimation/cubic_spline_estimator.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
from causal_testing.specification.variable import Variable
1010
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
11+
from causal_testing.testing.base_test_case import BaseTestCase
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -20,11 +21,10 @@ class CubicSplineRegressionEstimator(LinearRegressionEstimator):
2021
def __init__(
2122
# pylint: disable=too-many-arguments
2223
self,
23-
treatment: str,
24+
base_test_case: BaseTestCase,
2425
treatment_value: float,
2526
control_value: float,
2627
adjustment_set: set,
27-
outcome: str,
2828
basis: int,
2929
df: pd.DataFrame = None,
3030
effect_modifiers: dict[Variable:Any] = None,
@@ -33,7 +33,7 @@ def __init__(
3333
expected_relationship=None,
3434
):
3535
super().__init__(
36-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
36+
base_test_case, treatment_value, control_value, adjustment_set, df, effect_modifiers, formula, alpha
3737
)
3838

3939
self.expected_relationship = expected_relationship
@@ -42,8 +42,10 @@ def __init__(
4242
effect_modifiers = []
4343

4444
if formula is None:
45-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
46-
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"
45+
terms = (
46+
[base_test_case.treatment_variable.name] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
47+
)
48+
self.formula = f"{base_test_case.outcome_variable.name} ~ cr({'+'.join(terms)}, df={basis})"
4749

4850
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
4951
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
@@ -59,7 +61,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
5961
"""
6062
model = self._run_regression()
6163

62-
x = {"Intercept": 1, self.treatment: self.treatment_value}
64+
x = {"Intercept": 1, self.base_test_case.treatment_variable.name: self.treatment_value}
6365
if adjustment_config is not None:
6466
for k, v in adjustment_config.items():
6567
x[k] = v
@@ -69,7 +71,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
6971

7072
treatment = model.predict(x).iloc[0]
7173

72-
x[self.treatment] = self.control_value
74+
x[self.base_test_case.treatment_variable.name] = self.control_value
7375
control = model.predict(x).iloc[0]
7476

7577
return pd.Series(treatment - control)

causal_testing/estimation/experimental_estimator.py

Lines changed: 53 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66

77
from causal_testing.estimation.abstract_estimator import Estimator
8+
from causal_testing.testing.base_test_case import BaseTestCase
89

910

1011
class ExperimentalEstimator(Estimator):
@@ -16,22 +17,20 @@ class ExperimentalEstimator(Estimator):
1617
def __init__(
1718
# pylint: disable=too-many-arguments
1819
self,
19-
treatment: str,
20+
base_test_case: BaseTestCase,
2021
treatment_value: float,
2122
control_value: float,
2223
adjustment_set: dict[str:Any],
23-
outcome: str,
2424
effect_modifiers: dict[str:Any] = None,
2525
alpha: float = 0.05,
2626
repeats: int = 200,
2727
):
2828
# pylint: disable=R0801
2929
super().__init__(
30-
treatment=treatment,
30+
base_test_case=base_test_case,
3131
treatment_value=treatment_value,
3232
control_value=control_value,
3333
adjustment_set=adjustment_set,
34-
outcome=outcome,
3534
effect_modifiers=effect_modifiers,
3635
alpha=alpha,
3736
)
@@ -62,21 +61,40 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
6261
6362
:return: The average treatment effect and the bootstrapped confidence intervals.
6463
"""
65-
control_configuration = self.adjustment_set | self.effect_modifiers | {self.treatment: self.control_value}
66-
treatment_configuration = self.adjustment_set | self.effect_modifiers | {self.treatment: self.treatment_value}
64+
control_configuration = (
65+
self.adjustment_set
66+
| self.effect_modifiers
67+
| {self.base_test_case.treatment_variable.name: self.control_value}
68+
)
69+
treatment_configuration = (
70+
self.adjustment_set
71+
| self.effect_modifiers
72+
| {self.base_test_case.treatment_variable.name: self.treatment_value}
73+
)
6774

6875
control_outcomes = pd.DataFrame([self.run_system(control_configuration) for _ in range(self.repeats)])
6976
treatment_outcomes = pd.DataFrame([self.run_system(treatment_configuration) for _ in range(self.repeats)])
7077

71-
difference = (treatment_outcomes[self.outcome] - control_outcomes[self.outcome]).sort_values().reset_index()
78+
difference = (
79+
(
80+
treatment_outcomes[self.base_test_case.outcome_variable.name]
81+
- control_outcomes[self.base_test_case.outcome_variable.name]
82+
)
83+
.sort_values()
84+
.reset_index()
85+
)
7286

7387
ci_low_index = round(self.repeats * (self.alpha / 2))
7488
ci_low = difference.iloc[ci_low_index]
7589
ci_high = difference.iloc[self.repeats - ci_low_index]
7690

77-
return pd.Series({self.treatment: difference.mean()[self.outcome]}), [
78-
pd.Series({self.treatment: ci_low[self.outcome]}),
79-
pd.Series({self.treatment: ci_high[self.outcome]}),
91+
return pd.Series(
92+
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
93+
), [
94+
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
95+
pd.Series(
96+
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
97+
),
8098
]
8199

82100
def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
@@ -85,19 +103,38 @@ def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
85103
86104
:return: The average treatment effect and the bootstrapped confidence intervals.
87105
"""
88-
control_configuration = self.adjustment_set | self.effect_modifiers | {self.treatment: self.control_value}
89-
treatment_configuration = self.adjustment_set | self.effect_modifiers | {self.treatment: self.treatment_value}
106+
control_configuration = (
107+
self.adjustment_set
108+
| self.effect_modifiers
109+
| {self.base_test_case.treatment_variable.name: self.control_value}
110+
)
111+
treatment_configuration = (
112+
self.adjustment_set
113+
| self.effect_modifiers
114+
| {self.base_test_case.treatment_variable.name: self.treatment_value}
115+
)
90116

91117
control_outcomes = pd.DataFrame([self.run_system(control_configuration) for _ in range(self.repeats)])
92118
treatment_outcomes = pd.DataFrame([self.run_system(treatment_configuration) for _ in range(self.repeats)])
93119

94-
difference = (treatment_outcomes[self.outcome] / control_outcomes[self.outcome]).sort_values().reset_index()
120+
difference = (
121+
(
122+
treatment_outcomes[self.base_test_case.outcome_variable.name]
123+
/ control_outcomes[self.base_test_case.outcome_variable.name]
124+
)
125+
.sort_values()
126+
.reset_index()
127+
)
95128

96129
ci_low_index = round(self.repeats * (self.alpha / 2))
97130
ci_low = difference.iloc[ci_low_index]
98131
ci_high = difference.iloc[self.repeats - ci_low_index]
99132

100-
return pd.Series({self.treatment: difference.mean()[self.outcome]}), [
101-
pd.Series({self.treatment: ci_low[self.outcome]}),
102-
pd.Series({self.treatment: ci_high[self.outcome]}),
133+
return pd.Series(
134+
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
135+
), [
136+
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
137+
pd.Series(
138+
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
139+
),
103140
]

causal_testing/estimation/instrumental_variable_estimator.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import statsmodels.api as sm
88

99
from causal_testing.estimation.abstract_estimator import Estimator
10+
from causal_testing.testing.base_test_case import BaseTestCase
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -21,22 +22,20 @@ def __init__(
2122
# pylint: disable=too-many-arguments
2223
# pylint: disable=duplicate-code
2324
self,
24-
treatment: str,
25+
base_test_case: BaseTestCase,
2526
treatment_value: float,
2627
control_value: float,
2728
adjustment_set: set,
28-
outcome: str,
2929
instrument: str,
3030
df: pd.DataFrame = None,
3131
alpha: float = 0.05,
3232
query: str = "",
3333
):
3434
super().__init__(
35-
treatment=treatment,
35+
base_test_case=base_test_case,
3636
treatment_value=treatment_value,
3737
control_value=control_value,
3838
adjustment_set=adjustment_set,
39-
outcome=outcome,
4039
df=df,
4140
effect_modifiers=None,
4241
alpha=alpha,
@@ -68,10 +67,10 @@ def estimate_iv_coefficient(self, df) -> float:
6867
outcome.
6968
"""
7069
# Estimate the total effect of instrument I on outcome Y = abI + c1
71-
ab = sm.OLS(df[self.outcome], df[[self.instrument]]).fit().params[self.instrument]
70+
ab = sm.OLS(df[self.base_test_case.outcome_variable.name], df[[self.instrument]]).fit().params[self.instrument]
7271

7372
# Estimate the direct effect of instrument I on treatment X = aI + c1
74-
a = sm.OLS(df[self.treatment], df[[self.instrument]]).fit().params[self.instrument]
73+
a = sm.OLS(df[self.base_test_case.treatment_variable.name], df[[self.instrument]]).fit().params[self.instrument]
7574

7675
# Estimate the coefficient of I on X by cancelling
7776
return ab / a

causal_testing/estimation/ipcw_estimator.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
from lifelines import CoxPHFitter
1212

1313
from causal_testing.estimation.abstract_estimator import Estimator
14+
from causal_testing.testing.base_test_case import BaseTestCase
15+
from causal_testing.specification.variable import Input, Output
1416

1517
logger = logging.getLogger(__name__)
1618

@@ -56,13 +58,12 @@ def __init__(
5658
treatment) with the most elements multiplied by `timesteps_per_observation`.
5759
"""
5860
super().__init__(
59-
[var for _, var, _ in treatment_strategy],
60-
[val for _, _, val in treatment_strategy],
61-
[val for _, _, val in control_strategy],
62-
None,
63-
outcome,
64-
df,
65-
None,
61+
base_test_case=BaseTestCase(Input("_", float), Output(outcome, float)),
62+
treatment_value=[val for _, _, val in treatment_strategy],
63+
control_value=[val for _, _, val in control_strategy],
64+
adjustment_set=None,
65+
df=df,
66+
effect_modifiers=None,
6667
alpha=alpha,
6768
query="",
6869
)

0 commit comments

Comments
 (0)