@@ -52,6 +52,38 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
52
52
self .scenario_execution_data_df = self .data_collector .collect_data (** kwargs )
53
53
self .minimal_adjustment_set = set ()
54
54
55
+ def execute_test_suite (
56
+ self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
57
+ ) -> list [CausalTestResult ]:
58
+ """Execute a suite of causal tests and return the results in a list"""
59
+ if self .scenario_execution_data_df .empty :
60
+ raise Exception ("No data has been loaded. Please call load_data prior to executing a causal test case." )
61
+ causal_test_results = []
62
+ for edge in test_suite :
63
+
64
+ logger .info ("treatment: %s" , edge .treatment_variable )
65
+ logger .info ("outcome: %s" , edge .outcome_variable )
66
+ minimal_adjustment_set = self .casual_dag .identification (edge )
67
+ minimal_adjustment_set = minimal_adjustment_set - {v .name for v in edge .treatment_variable }
68
+ minimal_adjustment_set = minimal_adjustment_set - {v .name for v in edge .outcome_variable }
69
+
70
+ variables_for_positivity = list (minimal_adjustment_set ) + edge .treatment_variable + edge .outcome_variable
71
+
72
+ if self ._check_positivity_violation (variables_for_positivity ):
73
+ # TODO: We should allow users to continue because positivity can be overcome with parametric models
74
+ # TODO: When we implement causal contracts, we should also note the positivity violation there
75
+ raise Exception ("POSITIVITY VIOLATION -- Cannot proceed." )
76
+
77
+ for estimator in edge ["estimators" ]:
78
+ if estimator .df is None :
79
+ estimator .df = self .scenario_execution_data_df
80
+
81
+ for test in edge ["tests" ]:
82
+ logger .info ("minimal_adjustment_set: %s" , self .minimal_adjustment_set )
83
+ causal_test_result = self ._return_causal_test_results (test_suite .estimate_type , estimator )
84
+ causal_test_results .append (causal_test_result )
85
+ return causal_test_results
86
+
55
87
def execute_test (
56
88
self , estimator : Estimator , causal_test_case : CausalTestCase , estimate_type : str = "ate"
57
89
) -> CausalTestResult :
@@ -100,7 +132,12 @@ def execute_test(
100
132
# TODO: We should allow users to continue because positivity can be overcome with parametric models
101
133
# TODO: When we implement causal contracts, we should also note the positivity violation there
102
134
raise Exception ("POSITIVITY VIOLATION -- Cannot proceed." )
135
+ causal_test_result = self ._return_causal_test_results (estimate_type , estimator , causal_test_case )
136
+ return causal_test_result
103
137
138
+ # TODO (MF) I think that the test oracle procedure should go in here.
139
+ # This way, the user can supply it as a function or something, which can be applied to the result of CI
140
+ def _return_causal_test_results (self , estimate_type , estimator , causal_test_case ):
104
141
# TODO: Some estimators also return the CATE. Find the best way to add this into the causal test engine.
105
142
if estimate_type == "cate" :
106
143
logger .debug ("calculating cate" )
@@ -166,9 +203,6 @@ def execute_test(
166
203
raise ValueError (f"Invalid estimate type { estimate_type } , expected 'ate', 'cate', or 'risk_ratio'" )
167
204
return causal_test_result
168
205
169
- # TODO (MF) I think that the test oracle procedure should go in here.
170
- # This way, the user can supply it as a function or something, which can be applied to the result of CI
171
-
172
206
def _check_positivity_violation (self , variables_list ):
173
207
"""Check whether the dataframe has a positivity violation relative to the specified variables list.
174
208
0 commit comments