Skip to content

Commit 23c8e42

Browse files
First draft of execute_test_suite method
1 parent 93c7a91 commit 23c8e42

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,38 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
5252
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
5353
self.minimal_adjustment_set = set()
5454

55+
def execute_test_suite(
56+
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
57+
) -> list[CausalTestResult]:
58+
"""Execute a suite of causal tests and return the results in a list"""
59+
if self.scenario_execution_data_df.empty:
60+
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
61+
causal_test_results = []
62+
for edge in test_suite:
63+
64+
logger.info("treatment: %s", edge.treatment_variable)
65+
logger.info("outcome: %s", edge.outcome_variable)
66+
minimal_adjustment_set = self.casual_dag.identification(edge)
67+
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in edge.treatment_variable}
68+
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in edge.outcome_variable}
69+
70+
variables_for_positivity = list(minimal_adjustment_set) + edge.treatment_variable + edge.outcome_variable
71+
72+
if self._check_positivity_violation(variables_for_positivity):
73+
# TODO: We should allow users to continue because positivity can be overcome with parametric models
74+
# TODO: When we implement causal contracts, we should also note the positivity violation there
75+
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
76+
77+
for estimator in edge["estimators"]:
78+
if estimator.df is None:
79+
estimator.df = self.scenario_execution_data_df
80+
81+
for test in edge["tests"]:
82+
logger.info("minimal_adjustment_set: %s", self.minimal_adjustment_set)
83+
causal_test_result = self._return_causal_test_results(test_suite.estimate_type, estimator)
84+
causal_test_results.append(causal_test_result)
85+
return causal_test_results
86+
5587
def execute_test(
5688
self, estimator: Estimator, causal_test_case: CausalTestCase, estimate_type: str = "ate"
5789
) -> CausalTestResult:
@@ -100,7 +132,12 @@ def execute_test(
100132
# TODO: We should allow users to continue because positivity can be overcome with parametric models
101133
# TODO: When we implement causal contracts, we should also note the positivity violation there
102134
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
135+
causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
136+
return causal_test_result
103137

138+
# TODO (MF) I think that the test oracle procedure should go in here.
139+
# This way, the user can supply it as a function or something, which can be applied to the result of CI
140+
def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
104141
# TODO: Some estimators also return the CATE. Find the best way to add this into the causal test engine.
105142
if estimate_type == "cate":
106143
logger.debug("calculating cate")
@@ -166,9 +203,6 @@ def execute_test(
166203
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
167204
return causal_test_result
168205

169-
# TODO (MF) I think that the test oracle procedure should go in here.
170-
# This way, the user can supply it as a function or something, which can be applied to the result of CI
171-
172206
def _check_positivity_violation(self, variables_list):
173207
"""Check whether the dataframe has a positivity violation relative to the specified variables list.
174208

0 commit comments

Comments
 (0)