8
8
from causal_testing .testing .causal_test_outcome import CausalTestResult
9
9
from causal_testing .testing .estimators import Estimator
10
10
from causal_testing .testing .base_causal_test import BaseCausalTest
11
+ from causal_testing .testing .causal_test_suite import CausalTestSuite
11
12
12
13
logger = logging .getLogger (__name__ )
13
14
@@ -53,13 +54,18 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
53
54
self .scenario_execution_data_df = self .data_collector .collect_data (** kwargs )
54
55
55
56
def execute_test_suite (
56
- self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
57
- ) -> list [CausalTestResult ]:
58
- """Execute a suite of causal tests and return the results in a list"""
57
+ self , test_suite : CausalTestSuite ) -> list [CausalTestResult ]:
58
+
59
+ """Execute a suite of causal tests and return the results in a list
60
+ :param test_suite: CasualTestSuite object
61
+ :return: test_suite results which contains a list of CausalTestResult objects
62
+
63
+ """
59
64
if self .scenario_execution_data_df .empty :
60
65
raise Exception ("No data has been loaded. Please call load_data prior to executing a causal test case." )
61
66
test_suite_results = {}
62
- for edge in test_suite :
67
+ test_suite_dict = test_suite .test_suite
68
+ for edge in test_suite_dict :
63
69
print ("edge: " )
64
70
print (edge )
65
71
logger .info ("treatment: %s" , edge .treatment_variable )
@@ -69,16 +75,16 @@ def execute_test_suite(
69
75
minimal_adjustment_set = minimal_adjustment_set - set (edge .outcome_variable .name )
70
76
71
77
variables_for_positivity = (
72
- list (minimal_adjustment_set ) + [edge .treatment_variable .name ] + [edge .outcome_variable .name ]
78
+ list (minimal_adjustment_set ) + [edge .treatment_variable .name ] + [edge .outcome_variable .name ]
73
79
)
74
80
if self ._check_positivity_violation (variables_for_positivity ):
75
81
# TODO: We should allow users to continue because positivity can be overcome with parametric models
76
82
# TODO: When we implement causal contracts, we should also note the positivity violation there
77
83
raise Exception ("POSITIVITY VIOLATION -- Cannot proceed." )
78
84
79
- estimators = test_suite [edge ]["estimators" ]
80
- tests = test_suite [edge ]["tests" ]
81
- estimate_type = test_suite [edge ]["estimate_type" ]
85
+ estimators = test_suite_dict [edge ]["estimators" ]
86
+ tests = test_suite_dict [edge ]["tests" ]
87
+ estimate_type = test_suite_dict [edge ]["estimate_type" ]
82
88
results = []
83
89
for EstimatorClass in estimators :
84
90
causal_test_results = []
@@ -104,7 +110,7 @@ def execute_test_suite(
104
110
return test_suite_results
105
111
106
112
def execute_test (
107
- self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
113
+ self , estimator : type (Estimator ), causal_test_case : CausalTestCase , estimate_type : str = "ate"
108
114
) -> CausalTestResult :
109
115
"""Execute a causal test case and return the causal test result.
110
116
@@ -151,7 +157,15 @@ def execute_test(
151
157
152
158
# TODO (MF) I think that the test oracle procedure should go in here.
153
159
# This way, the user can supply it as a function or something, which can be applied to the result of CI
160
+
154
161
def _return_causal_test_results (self , estimate_type , estimator , causal_test_case ):
162
+ """Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
163
+
164
+ :param estimate_type: A string which denotes the type of estimate to return
165
+ :param estimator: An Estimator class object
166
+ :param causal_test_case: The concrete test case to be executed
167
+ :return: a CausalTestResult object containing the confidence intervals
168
+ """
155
169
# TODO: Some estimators also return the CATE. Find the best way to add this into the causal test engine.
156
170
if estimate_type == "cate" :
157
171
logger .debug ("calculating cate" )
0 commit comments