Skip to content

Commit 1d93d00

Browse files
Initial refactor of execute_test_case to causal_test_case.py
1 parent e028e84 commit 1d93d00

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
import logging
33
from typing import Any
44

5+
import pandas as pd
6+
57
from causal_testing.specification.variable import Variable
68
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
79
from causal_testing.testing.base_test_case import BaseTestCase
10+
from causal_testing.testing.estimators import Estimator
11+
from causal_testing.testing.causal_test_result import CausalTestResult
12+
from causal_testing.data_collection.data_collector import DataCollector
813

914
logger = logging.getLogger(__name__)
1015

@@ -73,6 +78,35 @@ def get_treatment_value(self):
7378
"""Return the treatment value of the treatment variable in this causal test case."""
7479
return self.treatment_value
7580

81+
def execute_test(self, estimator: type(Estimator), dataframe: pd.DataFrame) -> CausalTestResult:
82+
"""Execute a causal test case and return the causal test result.
83+
84+
:param estimator: A reference to an Estimator class.
85+
:param causal_test_case: The CausalTestCase object to be tested
86+
:return causal_test_result: A CausalTestResult for the executed causal test case.
87+
"""
88+
if self.scenario_execution_data_df.empty:
89+
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
90+
if estimator.df is None:
91+
estimator.df = dataframe
92+
treatment_variable = self.treatment_variable
93+
treatments = treatment_variable.name
94+
outcome_variable = self.outcome_variable
95+
96+
logger.info("treatments: %s", treatments)
97+
logger.info("outcomes: %s", outcome_variable)
98+
minimal_adjustment_set = self.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable))
99+
minimal_adjustment_set = minimal_adjustment_set - set(treatment_variable.name)
100+
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
101+
102+
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]
103+
104+
if self._check_positivity_violation(variables_for_positivity):
105+
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
106+
107+
causal_test_result = self._return_causal_test_results(estimator)
108+
return causal_test_result
109+
76110
def __str__(self):
77111
treatment_config = {self.treatment_variable.name: self.treatment_value}
78112
control_config = {self.treatment_variable.name: self.control_value}

0 commit comments

Comments
 (0)