|
2 | 2 | import logging
|
3 | 3 | from typing import Any
|
4 | 4 |
|
| 5 | +import pandas as pd |
| 6 | + |
5 | 7 | from causal_testing.specification.variable import Variable
|
6 | 8 | from causal_testing.testing.causal_test_outcome import CausalTestOutcome
|
7 | 9 | from causal_testing.testing.base_test_case import BaseTestCase
|
| 10 | +from causal_testing.testing.estimators import Estimator |
| 11 | +from causal_testing.testing.causal_test_result import CausalTestResult |
| 12 | +from causal_testing.data_collection.data_collector import DataCollector |
8 | 13 |
|
9 | 14 | logger = logging.getLogger(__name__)
|
10 | 15 |
|
@@ -73,6 +78,35 @@ def get_treatment_value(self):
|
73 | 78 | """Return the treatment value of the treatment variable in this causal test case."""
|
74 | 79 | return self.treatment_value
|
75 | 80 |
|
| 81 | + def execute_test(self, estimator: type(Estimator), dataframe: pd.DataFrame) -> CausalTestResult: |
| 82 | + """Execute a causal test case and return the causal test result. |
| 83 | +
|
| 84 | + :param estimator: A reference to an Estimator class. |
| 85 | + :param causal_test_case: The CausalTestCase object to be tested |
| 86 | + :return causal_test_result: A CausalTestResult for the executed causal test case. |
| 87 | + """ |
| 88 | + if self.scenario_execution_data_df.empty: |
| 89 | + raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.") |
| 90 | + if estimator.df is None: |
| 91 | + estimator.df = dataframe |
| 92 | + treatment_variable = self.treatment_variable |
| 93 | + treatments = treatment_variable.name |
| 94 | + outcome_variable = self.outcome_variable |
| 95 | + |
| 96 | + logger.info("treatments: %s", treatments) |
| 97 | + logger.info("outcomes: %s", outcome_variable) |
| 98 | + minimal_adjustment_set = self.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable)) |
| 99 | + minimal_adjustment_set = minimal_adjustment_set - set(treatment_variable.name) |
| 100 | + minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name) |
| 101 | + |
| 102 | + variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name] |
| 103 | + |
| 104 | + if self._check_positivity_violation(variables_for_positivity): |
| 105 | + raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.") |
| 106 | + |
| 107 | + causal_test_result = self._return_causal_test_results(estimator) |
| 108 | + return causal_test_result |
| 109 | + |
76 | 110 | def __str__(self):
|
77 | 111 | treatment_config = {self.treatment_variable.name: self.treatment_value}
|
78 | 112 | control_config = {self.treatment_variable.name: self.control_value}
|
|
0 commit comments