Skip to content

Commit bf23559

Browse files
black
1 parent 29e6dcd commit bf23559

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
5252
self.data_collector = data_collector
5353
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
5454

55-
5655
def execute_test_suite(
57-
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
56+
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
5857
) -> list[CausalTestResult]:
5958
"""Execute a suite of causal tests and return the results in a list"""
6059
if self.scenario_execution_data_df.empty:
@@ -69,8 +68,9 @@ def execute_test_suite(
6968
minimal_adjustment_set = minimal_adjustment_set - set(edge.treatment_variable.name)
7069
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7170

72-
variables_for_positivity = list(minimal_adjustment_set) + [edge.treatment_variable.name] + [
73-
edge.outcome_variable.name]
71+
variables_for_positivity = (
72+
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
73+
)
7474
if self._check_positivity_violation(variables_for_positivity):
7575
# TODO: We should allow users to continue because positivity can be overcome with parametric models
7676
# TODO: When we implement causal contracts, we should also note the positivity violation there
@@ -87,8 +87,13 @@ def execute_test_suite(
8787
treatment_variable = list(test.treatment_input_configuration.keys())[0]
8888
treatment_value = list(test.treatment_input_configuration.values())[0]
8989
control_value = list(test.control_input_configuration.values())[0]
90-
estimator = EstimatorClass((treatment_variable.name,), treatment_value, control_value,
91-
minimal_adjustment_set, (test.outcome_variable.name,))
90+
estimator = EstimatorClass(
91+
(treatment_variable.name,),
92+
treatment_value,
93+
control_value,
94+
minimal_adjustment_set,
95+
(test.outcome_variable.name,),
96+
)
9297
if estimator.df is None:
9398
estimator.df = self.scenario_execution_data_df
9499
causal_test_result = self._return_causal_test_results(estimate_type, estimator, test)
@@ -99,7 +104,7 @@ def execute_test_suite(
99104
return test_suite_results
100105

101106
def execute_test(
102-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
107+
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
103108
) -> CausalTestResult:
104109
"""Execute a causal test case and return the causal test result.
105110
@@ -129,18 +134,13 @@ def execute_test(
129134
logger.info("treatments: %s", treatments)
130135
logger.info("outcomes: %s", outcome_variable)
131136
minimal_adjustment_set = self.causal_dag.identification(BaseCausalTest(treatment_variable, outcome_variable))
132-
minimal_adjustment_set = minimal_adjustment_set - {
133-
v.name for v in causal_test_case.control_input_configuration
134-
}
137+
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in causal_test_case.control_input_configuration}
135138
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
136139
assert all(
137140
(v.name not in minimal_adjustment_set for v in causal_test_case.control_input_configuration)
138141
), "Treatment vars in adjustment set"
139-
assert (
140-
outcome_variable not in minimal_adjustment_set
141-
), "Outcome vars in adjustment set"
142-
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [
143-
outcome_variable.name]
142+
assert outcome_variable not in minimal_adjustment_set, "Outcome vars in adjustment set"
143+
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]
144144

145145
if self._check_positivity_violation(variables_for_positivity):
146146
# TODO: We should allow users to continue because positivity can be overcome with parametric models

0 commit comments

Comments
 (0)