5
5
import json
6
6
import logging
7
7
8
- from collections .abc import Iterable , Mapping
8
+ from collections .abc import Mapping
9
9
from dataclasses import dataclass
10
10
from pathlib import Path
11
11
from statistics import StatisticsError
@@ -86,11 +86,9 @@ def setup(self, scenario: Scenario):
86
86
"No data found, either provide a path to a file containing data or manually populate the .data "
87
87
"attribute with a dataframe before calling .setup()"
88
88
)
89
- self .data_collector = ObservationalDataCollector (
90
- self .scenario , data )
89
+ self .data_collector = ObservationalDataCollector (self .scenario , data )
91
90
self ._populate_metas ()
92
91
93
-
94
92
def _create_abstract_test_case (self , test , mutates , effects ):
95
93
assert len (test ["mutations" ]) == 1
96
94
treatment_var = next (self .scenario .variables [v ] for v in test ["mutations" ])
@@ -156,11 +154,11 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
156
154
failed , _ = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
157
155
158
156
msg = (
159
- f"Executing concrete test: { test ['name' ]} \n "
160
- + f"treatment variable: { test ['treatment_variable' ]} \n "
161
- + f"outcome_variable = { outcome_variable } \n "
162
- + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
163
- + f"Result: { 'FAILED' if failed else 'Passed' } "
157
+ f"Executing concrete test: { test ['name' ]} \n "
158
+ + f"treatment variable: { test ['treatment_variable' ]} \n "
159
+ + f"outcome_variable = { outcome_variable } \n "
160
+ + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
161
+ + f"Result: { 'FAILED' if failed else 'Passed' } "
164
162
)
165
163
print (msg )
166
164
self ._append_to_file (msg , logging .INFO )
@@ -187,12 +185,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
187
185
)
188
186
result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
189
187
msg = (
190
- f"Executing test: { test ['name' ]} \n "
191
- + f" { causal_test_case } \n "
192
- + " "
193
- + ("\n " ).join (str (result [1 ]).split ("\n " ))
194
- + "==============\n "
195
- + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
188
+ f"Executing test: { test ['name' ]} \n "
189
+ + f" { causal_test_case } \n "
190
+ + " "
191
+ + ("\n " ).join (str (result [1 ]).split ("\n " ))
192
+ + "==============\n "
193
+ + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
196
194
)
197
195
return msg
198
196
@@ -220,13 +218,13 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
220
218
failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
221
219
222
220
msg = (
223
- f"Executing test: { test ['name' ]} \n "
224
- + " abstract_test \n "
225
- + f" { abstract_test } \n "
226
- + f" { abstract_test .treatment_variable .name } ,"
227
- + f" { abstract_test .treatment_variable .distribution } \n "
228
- + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
229
- + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
221
+ f"Executing test: { test ['name' ]} \n "
222
+ + " abstract_test \n "
223
+ + f" { abstract_test } \n "
224
+ + f" { abstract_test .treatment_variable .name } ,"
225
+ + f" { abstract_test .treatment_variable .distribution } \n "
226
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
227
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
230
228
)
231
229
return msg
232
230
@@ -251,8 +249,7 @@ def _populate_metas(self):
251
249
meta .populate (self .data_collector .data )
252
250
253
251
def _execute_test_case (
254
- self , causal_test_case : CausalTestCase , test : Mapping ,
255
- f_flag : bool
252
+ self , causal_test_case : CausalTestCase , test : Mapping , f_flag : bool
256
253
) -> (bool , CausalTestResult ):
257
254
"""Executes a singular test case, prints the results and returns the test case result
258
255
:param causal_test_case: The concrete test case to be executed
@@ -265,8 +262,9 @@ def _execute_test_case(
265
262
failed = False
266
263
267
264
estimation_model = self ._setup_test (causal_test_case = causal_test_case , test = test )
268
- causal_test_result = causal_test_case .execute_test (estimator = estimation_model ,
269
- data_collector = self .data_collector )
265
+ causal_test_result = causal_test_case .execute_test (
266
+ estimator = estimation_model , data_collector = self .data_collector
267
+ )
270
268
271
269
test_passes = causal_test_case .expected_causal_effect .apply (causal_test_result )
272
270
@@ -288,9 +286,7 @@ def _execute_test_case(
288
286
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
289
287
return failed , causal_test_result
290
288
291
- def _setup_test (
292
- self , causal_test_case : CausalTestCase , test : Mapping
293
- ) -> Estimator :
289
+ def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping ) -> Estimator :
294
290
"""Create the necessary inputs for a single test case
295
291
:param causal_test_case: The concrete test case to be executed
296
292
:param test: Single JSON test definition stored in a mapping (dict)
@@ -368,7 +364,7 @@ def get_args(test_args=None) -> argparse.Namespace:
368
364
parser .add_argument (
369
365
"-w" ,
370
366
help = "Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
371
- "careful" ,
367
+ "careful" ,
372
368
action = "store_true" ,
373
369
)
374
370
parser .add_argument (
0 commit comments