Skip to content

Commit b761795

Browse files
Pass in datacollector not dataframe
1 parent 51c5c10 commit b761795

File tree

2 files changed

+68
-18
lines changed

2 files changed

+68
-18
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class ObservationalDataCollector(DataCollector):
140140
def __init__(self, scenario: Scenario, data: pd.DataFrame):
141141
super().__init__(scenario)
142142
self.data = data
143+
self.data_checked = False
143144

144145
def collect_data(self, **kwargs) -> pd.DataFrame:
145146
"""Read a pandas dataframe and filter to remove
@@ -149,7 +150,6 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
149150
150151
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
151152
"""
152-
153153
execution_data_df = self.data
154154
for meta in self.scenario.metas():
155155
if meta.name not in self.data:
@@ -158,4 +158,5 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
158158
for var_name, var in self.scenario.variables.items():
159159
if issubclass(var.datatype, Enum):
160160
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
161-
return scenario_execution_data_df
161+
self.data_checked = True
162+
self.data = scenario_execution_data_df

causal_testing/testing/causal_test_case.py

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from causal_testing.testing.base_test_case import BaseTestCase
1010
from causal_testing.testing.estimators import Estimator
1111
from causal_testing.testing.causal_test_result import CausalTestResult
12-
from causal_testing.data_collection.data_collector import DataCollector
12+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
13+
from causal_testing.specification.causal_dag import CausalDAG
14+
from causal_testing.specification.scenario import Scenario
1315

16+
from causal_testing.specification.causal_specification import CausalSpecification
1417
logger = logging.getLogger(__name__)
1518

1619

@@ -28,15 +31,15 @@ class CausalTestCase:
2831
"""
2932

3033
def __init__(
31-
# pylint: disable=too-many-arguments
32-
self,
33-
base_test_case: BaseTestCase,
34-
expected_causal_effect: CausalTestOutcome,
35-
control_value: Any = None,
36-
treatment_value: Any = None,
37-
estimate_type: str = "ate",
38-
estimate_params: dict = None,
39-
effect_modifier_configuration: dict[Variable:Any] = None,
34+
# pylint: disable=too-many-arguments
35+
self,
36+
base_test_case: BaseTestCase,
37+
expected_causal_effect: CausalTestOutcome,
38+
control_value: Any = None,
39+
treatment_value: Any = None,
40+
estimate_type: str = "ate",
41+
estimate_params: dict = None,
42+
effect_modifier_configuration: dict[Variable:Any] = None,
4043
):
4144
"""
4245
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
@@ -78,35 +81,81 @@ def get_treatment_value(self):
7881
"""Return the treatment value of the treatment variable in this causal test case."""
7982
return self.treatment_value
8083

81-
def execute_test(self, estimator: type(Estimator), dataframe: pd.DataFrame) -> CausalTestResult:
84+
def execute_test(self, estimator: type(Estimator), data_collector: ObservationalDataCollector, causal_specification: CausalSpecification) -> CausalTestResult:
8285
"""Execute a causal test case and return the causal test result.
8386
8487
:param estimator: A reference to an Estimator class.
8588
:param causal_test_case: The CausalTestCase object to be tested
8689
:return causal_test_result: A CausalTestResult for the executed causal test case.
8790
"""
88-
if self.scenario_execution_data_df.empty:
89-
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
91+
if not data_collector.data_checked:
92+
data_collector.collect_data()
9093
if estimator.df is None:
91-
estimator.df = dataframe
94+
estimator.df = data_collector.data
9295
treatment_variable = self.treatment_variable
9396
treatments = treatment_variable.name
9497
outcome_variable = self.outcome_variable
9598

9699
logger.info("treatments: %s", treatments)
97100
logger.info("outcomes: %s", outcome_variable)
98-
minimal_adjustment_set = self.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable))
101+
minimal_adjustment_set = causal_specification.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable))
99102
minimal_adjustment_set = minimal_adjustment_set - set(treatment_variable.name)
100103
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
101104

102105
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]
103106

104-
if self._check_positivity_violation(variables_for_positivity):
107+
if self._check_positivity_violation(variables_for_positivity, causal_specification.scenario, data_collector.data):
105108
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
106109

107110
causal_test_result = self._return_causal_test_results(estimator)
108111
return causal_test_result
109112

113+
def _return_causal_test_results(self, estimator, causal_test_case):
114+
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
115+
116+
:param estimator: An Estimator class object
117+
:param causal_test_case: The concrete test case to be executed
118+
:return: a CausalTestResult object containing the confidence intervals
119+
"""
120+
if not hasattr(estimator, f"estimate_{causal_test_case.estimate_type}"):
121+
raise AttributeError(f"{estimator.__class__} has no {causal_test_case.estimate_type} method.")
122+
estimate_effect = getattr(estimator, f"estimate_{causal_test_case.estimate_type}")
123+
effect, confidence_intervals = estimate_effect(**causal_test_case.estimate_params)
124+
causal_test_result = CausalTestResult(
125+
estimator=estimator,
126+
test_value=TestValue(causal_test_case.estimate_type, effect),
127+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
128+
confidence_intervals=confidence_intervals,
129+
)
130+
131+
return causal_test_result
132+
133+
def _check_positivity_violation(self, variables_list, scenario: Scenario, df):
134+
"""Check whether the dataframe has a positivity violation relative to the specified variables list.
135+
136+
A positivity violation occurs when there is a stratum of the dataframe which does not have any data. Put simply,
137+
if we split the dataframe into covariate sub-groups, each sub-group must contain both a treated and untreated
138+
individual. If a positivity violation occurs, causal inference is still possible using a properly specified
139+
parametric estimator. Therefore, we should not throw an exception upon violation but raise a warning instead.
140+
141+
:param variables_list: The list of variables for which positivity must be satisfied.
142+
:return: True if positivity is violated, False otherwise.
143+
"""
144+
if not (set(variables_list) - {x.name for x in scenario.hidden_variables()}).issubset(
145+
df.columns
146+
):
147+
missing_variables = set(variables_list) - set(df.columns)
148+
logger.warning(
149+
"Positivity violation: missing data for variables %s.\n"
150+
"Causal inference is only valid if a well-specified parametric model is used.\n"
151+
"Alternatively, consider restricting analysis to executions without the variables:"
152+
".",
153+
missing_variables,
154+
)
155+
return True
156+
157+
return False
158+
110159
def __str__(self):
111160
treatment_config = {self.treatment_variable.name: self.treatment_value}
112161
control_config = {self.treatment_variable.name: self.control_value}

0 commit comments

Comments
 (0)