@@ -68,9 +68,8 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
68
68
data_paths = []
69
69
self .input_paths = JsonClassPaths (json_path = json_path , dag_path = dag_path , data_paths = data_paths )
70
70
71
- def setup (self , scenario : Scenario ):
71
+ def setup (self , scenario : Scenario , data = [] ):
72
72
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
73
- data = []
74
73
self .scenario = scenario
75
74
self ._get_scenario_variables ()
76
75
self .scenario .setup_treatment_variables ()
@@ -85,7 +84,7 @@ def setup(self, scenario: Scenario):
85
84
data = pd .concat ([pd .read_csv (data_file , header = 0 ) for data_file in self .input_paths .data_paths ])
86
85
if len (data ) == 0 :
87
86
raise ValueError (
88
- "No data found, either provide a path to a file containing data or manually populate the .data "
87
+ "No data found. Please either provide a path to a file containing data or manually populate the .data "
89
88
"attribute with a dataframe before calling .setup()"
90
89
)
91
90
self .data_collector = ObservationalDataCollector (self .scenario , data )
@@ -132,39 +131,18 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
132
131
test ["estimator" ] = estimators [test ["estimator" ]]
133
132
# If we have specified concrete control and treatment value
134
133
if "mutations" not in test :
135
- failed , msg = self ._run_concrete_metamorphic_test ()
134
+ failed , msg = self ._run_concrete_metamorphic_test (test , f_flag , effects , mutates )
136
135
# If we have a variable to mutate
137
136
else :
138
- outcome_variable = next (
139
- iter (test ["expected_effect" ])
140
- ) # Take first key from dictionary of expected effect
141
- base_test_case = BaseTestCase (
142
- treatment_variable = self .variables ["inputs" ][test ["treatment_variable" ]],
143
- outcome_variable = self .variables ["outputs" ][outcome_variable ],
144
- )
145
-
146
- causal_test_case = CausalTestCase (
147
- base_test_case = base_test_case ,
148
- expected_causal_effect = effects [test ["expected_effect" ][outcome_variable ]],
149
- control_value = test ["control_value" ],
150
- treatment_value = test ["treatment_value" ],
151
- estimate_type = test ["estimate_type" ],
152
- )
153
-
154
- failed , msg = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
155
-
156
- # msg = (
157
- # f"Executing concrete test: {test['name']} \n"
158
- # + f"treatment variable: {test['treatment_variable']} \n"
159
- # + f"outcome_variable = {outcome_variable} \n"
160
- # + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
161
- # + f"Result: {'FAILED' if failed else 'Passed'}"
162
- # )
163
- # print(msg)
164
- self ._append_to_file (msg , logging .INFO )
137
+ if test ["estimate_type" ] == "coefficient" :
138
+ failed , msg = self ._run_coefficient_test (test = test , f_flag = f_flag , effects = effects )
139
+ else :
140
+ failed , msg = self ._run_metamorphic_tests (
141
+ test = test , f_flag = f_flag , effects = effects , mutates = mutates
142
+ )
165
143
test ["failed" ] = failed
166
144
test ["result" ] = msg
167
- return self .test_plan ["tests" ]
145
+ return self .test_plan ["tests" ]
168
146
169
147
def _run_coefficient_test (self , test : dict , f_flag : bool , effects : dict ):
170
148
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -299,7 +277,7 @@ def _execute_test_case(
299
277
test_passes = causal_test_case .expected_causal_effect .apply (causal_test_result )
300
278
301
279
if "coverage" in test and test ["coverage" ]:
302
- adequacy_metric = DataAdequacy (causal_test_case , causal_test_engine , estimation_model )
280
+ adequacy_metric = DataAdequacy (causal_test_case , estimation_model , self . data_collector )
303
281
adequacy_metric .measure_adequacy ()
304
282
# self._append_to_file(f"KURTOSIS: {effect_estimate.mean()}", logging.INFO)
305
283
# self._append_to_file(f"PASSING: {sum(outcomes)}/{len(outcomes)}", logging.INFO)
@@ -331,7 +309,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
331
309
data. Conditions should be in the query format detailed at
332
310
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
333
311
:returns:
334
- - causal_test_engine - Test Engine instance for the test being run
335
312
- estimation_model - Estimator instance for the test being run
336
313
"""
337
314
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
0 commit comments