@@ -52,9 +52,8 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
52
52
self .data_collector = data_collector
53
53
self .scenario_execution_data_df = self .data_collector .collect_data (** kwargs )
54
54
55
-
56
55
def execute_test_suite (
57
- self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
56
+ self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
58
57
) -> list [CausalTestResult ]:
59
58
"""Execute a suite of causal tests and return the results in a list"""
60
59
if self .scenario_execution_data_df .empty :
@@ -69,8 +68,9 @@ def execute_test_suite(
69
68
minimal_adjustment_set = minimal_adjustment_set - set (edge .treatment_variable .name )
70
69
minimal_adjustment_set = minimal_adjustment_set - set (edge .outcome_variable .name )
71
70
72
- variables_for_positivity = list (minimal_adjustment_set ) + [edge .treatment_variable .name ] + [
73
- edge .outcome_variable .name ]
71
+ variables_for_positivity = (
72
+ list (minimal_adjustment_set ) + [edge .treatment_variable .name ] + [edge .outcome_variable .name ]
73
+ )
74
74
if self ._check_positivity_violation (variables_for_positivity ):
75
75
# TODO: We should allow users to continue because positivity can be overcome with parametric models
76
76
# TODO: When we implement causal contracts, we should also note the positivity violation there
@@ -87,8 +87,13 @@ def execute_test_suite(
87
87
treatment_variable = list (test .treatment_input_configuration .keys ())[0 ]
88
88
treatment_value = list (test .treatment_input_configuration .values ())[0 ]
89
89
control_value = list (test .control_input_configuration .values ())[0 ]
90
- estimator = EstimatorClass ((treatment_variable .name ,), treatment_value , control_value ,
91
- minimal_adjustment_set , (test .outcome_variable .name ,))
90
+ estimator = EstimatorClass (
91
+ (treatment_variable .name ,),
92
+ treatment_value ,
93
+ control_value ,
94
+ minimal_adjustment_set ,
95
+ (test .outcome_variable .name ,),
96
+ )
92
97
if estimator .df is None :
93
98
estimator .df = self .scenario_execution_data_df
94
99
causal_test_result = self ._return_causal_test_results (estimate_type , estimator , test )
@@ -99,7 +104,7 @@ def execute_test_suite(
99
104
return test_suite_results
100
105
101
106
def execute_test (
102
- self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
107
+ self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
103
108
) -> CausalTestResult :
104
109
"""Execute a causal test case and return the causal test result.
105
110
@@ -129,18 +134,13 @@ def execute_test(
129
134
logger .info ("treatments: %s" , treatments )
130
135
logger .info ("outcomes: %s" , outcome_variable )
131
136
minimal_adjustment_set = self .causal_dag .identification (BaseCausalTest (treatment_variable , outcome_variable ))
132
- minimal_adjustment_set = minimal_adjustment_set - {
133
- v .name for v in causal_test_case .control_input_configuration
134
- }
137
+ minimal_adjustment_set = minimal_adjustment_set - {v .name for v in causal_test_case .control_input_configuration }
135
138
minimal_adjustment_set = minimal_adjustment_set - set (outcome_variable .name )
136
139
assert all (
137
140
(v .name not in minimal_adjustment_set for v in causal_test_case .control_input_configuration )
138
141
), "Treatment vars in adjustment set"
139
- assert (
140
- outcome_variable not in minimal_adjustment_set
141
- ), "Outcome vars in adjustment set"
142
- variables_for_positivity = list (minimal_adjustment_set ) + [treatment_variable .name ] + [
143
- outcome_variable .name ]
142
+ assert outcome_variable not in minimal_adjustment_set , "Outcome vars in adjustment set"
143
+ variables_for_positivity = list (minimal_adjustment_set ) + [treatment_variable .name ] + [outcome_variable .name ]
144
144
145
145
if self ._check_positivity_violation (variables_for_positivity ):
146
146
# TODO: We should allow users to continue because positivity can be overcome with parametric models
0 commit comments