Skip to content

Commit c5d6a8f

Browse files
authored
Merge pull request #308 from CITCOM-project/jmafoster1/remove-data-collector
Jmafoster1/remove data collector
2 parents 1240348 + 15fa8ee commit c5d6a8f

File tree

73 files changed

+1377
-5268
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+1377
-5268
lines changed

causal_testing/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
22
This is the CausalTestingFramework Module
33
It contains 5 subpackages:
4-
data_collection
5-
generation
6-
json_front
4+
estimation
75
specification
6+
surrogate
87
testing
8+
utils
99
"""
1010

1111
import logging

causal_testing/data_collection/__init__.py

Whitespace-only changes.

causal_testing/data_collection/data_collector.py

Lines changed: 0 additions & 161 deletions
This file was deleted.

causal_testing/estimation/abstract_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import pandas as pd
88

9+
from causal_testing.testing.base_test_case import BaseTestCase
10+
911
logger = logging.getLogger(__name__)
1012

1113

@@ -29,22 +31,21 @@ class Estimator(ABC):
2931

3032
def __init__(
3133
# pylint: disable=too-many-arguments
34+
# pylint: disable=R0801
3235
self,
33-
treatment: str,
36+
base_test_case: BaseTestCase,
3437
treatment_value: float,
3538
control_value: float,
3639
adjustment_set: set,
37-
outcome: str,
3840
df: pd.DataFrame = None,
3941
effect_modifiers: dict[str:Any] = None,
4042
alpha: float = 0.05,
4143
query: str = "",
4244
):
43-
self.treatment = treatment
45+
self.base_test_case = base_test_case
4446
self.treatment_value = treatment_value
4547
self.control_value = control_value
4648
self.adjustment_set = adjustment_set
47-
self.outcome = outcome
4849
self.alpha = alpha
4950
self.df = df.query(query) if query else df
5051

causal_testing/estimation/abstract_regression_estimator.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from causal_testing.specification.variable import Variable
1212
from causal_testing.estimation.abstract_estimator import Estimator
13+
from causal_testing.testing.base_test_case import BaseTestCase
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -22,23 +23,22 @@ class RegressionEstimator(Estimator):
2223
def __init__(
2324
# pylint: disable=too-many-arguments
2425
self,
25-
treatment: str,
26+
base_test_case: BaseTestCase,
2627
treatment_value: float,
2728
control_value: float,
2829
adjustment_set: set,
29-
outcome: str,
3030
df: pd.DataFrame = None,
3131
effect_modifiers: dict[Variable:Any] = None,
3232
formula: str = None,
3333
alpha: float = 0.05,
3434
query: str = "",
3535
):
36+
# pylint: disable=R0801
3637
super().__init__(
37-
treatment=treatment,
38+
base_test_case=base_test_case,
3839
treatment_value=treatment_value,
3940
control_value=control_value,
4041
adjustment_set=adjustment_set,
41-
outcome=outcome,
4242
df=df,
4343
effect_modifiers=effect_modifiers,
4444
alpha=alpha,
@@ -53,8 +53,10 @@ def __init__(
5353
if formula is not None:
5454
self.formula = formula
5555
else:
56-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
57-
self.formula = f"{outcome} ~ {'+'.join(terms)}"
56+
terms = (
57+
[base_test_case.treatment_variable.name] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
58+
)
59+
self.formula = f"{base_test_case.outcome_variable.name} ~ {'+'.join(terms)}"
5860

5961
@property
6062
@abstractmethod
@@ -104,7 +106,7 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
104106

105107
x = pd.DataFrame(columns=self.df.columns)
106108
x["Intercept"] = 1 # self.intercept
107-
x[self.treatment] = [self.treatment_value, self.control_value]
109+
x[self.base_test_case.treatment_variable.name] = [self.treatment_value, self.control_value]
108110

109111
for k, v in adjustment_config.items():
110112
x[k] = v
@@ -116,5 +118,5 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
116118
x = pd.get_dummies(x, columns=[col], drop_first=True)
117119

118120
# This has to be here in case the treatment variable is in an I(...) block in the self.formula
119-
x[self.treatment] = [self.treatment_value, self.control_value]
121+
x[self.base_test_case.treatment_variable.name] = [self.treatment_value, self.control_value]
120122
return model.get_prediction(x).summary_frame()

causal_testing/estimation/cubic_spline_estimator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from causal_testing.specification.variable import Variable
1010
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
11+
from causal_testing.testing.base_test_case import BaseTestCase
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -20,11 +21,10 @@ class CubicSplineRegressionEstimator(LinearRegressionEstimator):
2021
def __init__(
2122
# pylint: disable=too-many-arguments
2223
self,
23-
treatment: str,
24+
base_test_case: BaseTestCase,
2425
treatment_value: float,
2526
control_value: float,
2627
adjustment_set: set,
27-
outcome: str,
2828
basis: int,
2929
df: pd.DataFrame = None,
3030
effect_modifiers: dict[Variable:Any] = None,
@@ -33,7 +33,7 @@ def __init__(
3333
expected_relationship=None,
3434
):
3535
super().__init__(
36-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
36+
base_test_case, treatment_value, control_value, adjustment_set, df, effect_modifiers, formula, alpha
3737
)
3838

3939
self.expected_relationship = expected_relationship
@@ -42,8 +42,10 @@ def __init__(
4242
effect_modifiers = []
4343

4444
if formula is None:
45-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
46-
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"
45+
terms = (
46+
[base_test_case.treatment_variable.name] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
47+
)
48+
self.formula = f"{base_test_case.outcome_variable.name} ~ cr({'+'.join(terms)}, df={basis})"
4749

4850
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
4951
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
@@ -59,7 +61,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
5961
"""
6062
model = self._run_regression()
6163

62-
x = {"Intercept": 1, self.treatment: self.treatment_value}
64+
x = {"Intercept": 1, self.base_test_case.treatment_variable.name: self.treatment_value}
6365
if adjustment_config is not None:
6466
for k, v in adjustment_config.items():
6567
x[k] = v
@@ -69,7 +71,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
6971

7072
treatment = model.predict(x).iloc[0]
7173

72-
x[self.treatment] = self.control_value
74+
x[self.base_test_case.treatment_variable.name] = self.control_value
7375
control = model.predict(x).iloc[0]
7476

7577
return pd.Series(treatment - control)

0 commit comments

Comments
 (0)