Skip to content

Commit c834ba5

Browse files
add execute_test_suite
1 parent fc10813 commit c834ba5

File tree

1 file changed

+35
-26
lines changed

1 file changed

+35
-26
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from causal_testing.testing.causal_test_case import CausalTestCase
88
from causal_testing.testing.causal_test_outcome import CausalTestResult
99
from causal_testing.testing.estimators import Estimator
10+
from causal_testing.testing.base_causal_test import BaseCausalTest
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -50,10 +51,10 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
5051
)
5152
self.data_collector = data_collector
5253
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
53-
self.minimal_adjustment_set = set()
54+
5455

5556
def execute_test_suite(
56-
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
57+
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
5758
) -> list[CausalTestResult]:
5859
"""Execute a suite of causal tests and return the results in a list"""
5960
if self.scenario_execution_data_df.empty:
@@ -63,29 +64,37 @@ def execute_test_suite(
6364

6465
logger.info("treatment: %s", edge.treatment_variable)
6566
logger.info("outcome: %s", edge.outcome_variable)
66-
minimal_adjustment_set = self.casual_dag.identification(edge)
67-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in edge.treatment_variable}
68-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in edge.outcome_variable}
69-
70-
variables_for_positivity = list(minimal_adjustment_set) + edge.treatment_variable + edge.outcome_variable
67+
minimal_adjustment_set = self.causal_dag.identification(edge)
68+
minimal_adjustment_set = minimal_adjustment_set - set(edge.treatment_variable.name)
69+
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7170

71+
variables_for_positivity = list(minimal_adjustment_set) + [edge.treatment_variable.name] + [
72+
edge.outcome_variable.name]
7273
if self._check_positivity_violation(variables_for_positivity):
7374
# TODO: We should allow users to continue because positivity can be overcome with parametric models
7475
# TODO: When we implement causal contracts, we should also note the positivity violation there
7576
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
7677

77-
for estimator in edge["estimators"]:
78-
if estimator.df is None:
79-
estimator.df = self.scenario_execution_data_df
80-
81-
for test in edge["tests"]:
82-
logger.info("minimal_adjustment_set: %s", self.minimal_adjustment_set)
83-
causal_test_result = self._return_causal_test_results(test_suite.estimate_type, estimator)
78+
estimators = test_suite[edge]["estimators"]
79+
tests = test_suite[edge]["tests"]
80+
estimate_type = test_suite[edge]["estimate_type"]
81+
82+
for EstimatorClass in estimators:
83+
84+
for test in tests:
85+
treatment_variable = list(test.treatment_input_configuration.keys())[0]
86+
treatment_value = list(test.treatment_input_configuration.values())[0]
87+
control_value = list(test.control_input_configuration.values())[0]
88+
estimator = EstimatorClass((treatment_variable.name,), treatment_value, control_value,
89+
minimal_adjustment_set, (test.outcome_variable.name,))
90+
if estimator.df is None:
91+
estimator.df = self.scenario_execution_data_df
92+
causal_test_result = self._return_causal_test_results(estimate_type, estimator, test)
8493
causal_test_results.append(causal_test_result)
8594
return causal_test_results
8695

8796
def execute_test(
88-
self, estimator: Estimator, causal_test_case: CausalTestCase, estimate_type: str = "ate"
97+
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
8998
) -> CausalTestResult:
9099
"""Execute a causal test case and return the causal test result.
91100
@@ -108,26 +117,26 @@ def execute_test(
108117
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
109118
if estimator.df is None:
110119
estimator.df = self.scenario_execution_data_df
111-
treatment_variables = list(causal_test_case.control_input_configuration)
112-
treatments = [v.name for v in treatment_variables]
113-
outcomes = [v.name for v in causal_test_case.outcome_variables]
120+
treatment_variable = list(causal_test_case.control_input_configuration.keys())[0]
121+
treatments = treatment_variable.name
122+
outcome_variable = causal_test_case.outcome_variable
114123

115124
logger.info("treatments: %s", treatments)
116-
logger.info("outcomes: %s", outcomes)
117-
logger.info("minimal_adjustment_set: %s", self.minimal_adjustment_set)
118-
119-
minimal_adjustment_set = self.minimal_adjustment_set - {
125+
logger.info("outcomes: %s", outcome_variable)
126+
minimal_adjustment_set = self.causal_dag.identification(BaseCausalTest(treatment_variable, outcome_variable))
127+
minimal_adjustment_set = minimal_adjustment_set - {
120128
v.name for v in causal_test_case.control_input_configuration
121129
}
122-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in causal_test_case.outcome_variables}
130+
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
123131
assert all(
124132
(v.name not in minimal_adjustment_set for v in causal_test_case.control_input_configuration)
125133
), "Treatment vars in adjustment set"
126-
assert all(
127-
(v.name not in minimal_adjustment_set for v in causal_test_case.outcome_variables)
134+
assert (
135+
outcome_variable not in minimal_adjustment_set
128136
), "Outcome vars in adjustment set"
137+
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [
138+
outcome_variable.name]
129139

130-
variables_for_positivity = list(minimal_adjustment_set) + treatments + outcomes
131140
if self._check_positivity_violation(variables_for_positivity):
132141
# TODO: We should allow users to continue because positivity can be overcome with parametric models
133142
# TODO: When we implement causal contracts, we should also note the positivity violation there

0 commit comments

Comments
 (0)