@@ -55,6 +55,7 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
         self.causal_specification = None
         self.output_path = Path(output_path)
         self.check_file_exists(self.output_path, output_overwrite)
+        self.data_collector = None
 
     def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
         """
@@ -87,6 +88,8 @@ def setup(self, scenario: Scenario):
                 "attribute with a dataframe before calling .setup()"
             )
         self._populate_metas()
+        self.data_collector = ObservationalDataCollector(
+            self.scenario, self.data)
 
     def _create_abstract_test_case(self, test, mutates, effects):
         assert len(test["mutations"]) == 1
@@ -149,14 +152,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
                 treatment_value=test["treatment_value"],
                 estimate_type=test["estimate_type"],
             )
+
             failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
 
             msg = (
-                f"Executing concrete test: {test['name']}\n"
-                + f"treatment variable: {test['treatment_variable']}\n"
-                + f"outcome_variable = {outcome_variable}\n"
-                + f"control value = {test['control_value']}, treatment value = {test['treatment_value']}\n"
-                + f"Result: {'FAILED' if failed else 'Passed'}"
+                f"Executing concrete test: {test['name']}\n"
+                + f"treatment variable: {test['treatment_variable']}\n"
+                + f"outcome_variable = {outcome_variable}\n"
+                + f"control value = {test['control_value']}, treatment value = {test['treatment_value']}\n"
+                + f"Result: {'FAILED' if failed else 'Passed'}"
             )
             print(msg)
             self._append_to_file(msg, logging.INFO)
@@ -183,12 +187,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
         )
         result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
         msg = (
-            f"Executing test: {test['name']}\n"
-            + f" {causal_test_case}\n"
-            + " "
-            + ("\n ").join(str(result[1]).split("\n"))
-            + "==============\n"
-            + f" Result: {'FAILED' if result[0] else 'Passed'}"
+            f"Executing test: {test['name']}\n"
+            + f" {causal_test_case}\n"
+            + " "
+            + ("\n ").join(str(result[1]).split("\n"))
+            + "==============\n"
+            + f" Result: {'FAILED' if result[0] else 'Passed'}"
         )
         return msg
 
@@ -216,13 +220,13 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
         failures, _ = self._execute_tests(concrete_tests, test, f_flag)
 
         msg = (
-            f"Executing test: {test['name']}\n"
-            + " abstract_test\n"
-            + f" {abstract_test}\n"
-            + f" {abstract_test.treatment_variable.name},"
-            + f" {abstract_test.treatment_variable.distribution}\n"
-            + f" Number of concrete tests for test case: {str(len(concrete_tests))}\n"
-            + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+            f"Executing test: {test['name']}\n"
+            + " abstract_test\n"
+            + f" {abstract_test}\n"
+            + f" {abstract_test.treatment_variable.name},"
+            + f" {abstract_test.treatment_variable.distribution}\n"
+            + f" Number of concrete tests for test case: {str(len(concrete_tests))}\n"
+            + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
         )
         return msg
 
@@ -231,6 +235,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
         details = []
         if "formula" in test:
             self._append_to_file(f"Estimator formula used for test: {test['formula']}")
+
         for concrete_test in concrete_tests:
             failed, result = self._execute_test_case(concrete_test, test, f_flag)
             details.append(result)
@@ -246,7 +251,8 @@ def _populate_metas(self):
             meta.populate(self.data)
 
     def _execute_test_case(
-        self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
+        self, causal_test_case: CausalTestCase, test: Iterable[Mapping],
+        f_flag: bool
     ) -> (bool, CausalTestResult):
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
@@ -258,10 +264,10 @@ def _execute_test_case(
         """
         failed = False
 
-        causal_test_engine, estimation_model = self._setup_test(
-            causal_test_case, test, test["conditions"] if "conditions" in test else None
-        )
-        causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
+        estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test, data=self.data_collector.data)
+        causal_test_result = causal_test_case.execute_test(estimator=estimation_model,
+                                                           data_collector=self.data_collector,
+                                                           causal_specification=self.causal_specification)
 
         test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
 
@@ -284,7 +290,7 @@ def _execute_test_case(
         return failed, causal_test_result
 
     def _setup_test(
-        self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
+        self, causal_test_case: CausalTestCase, test: Mapping, data: pd.DataFrame
     ) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
@@ -296,12 +302,6 @@ def _setup_test(
             - causal_test_engine - Test Engine instance for the test being run
             - estimation_model - Estimator instance for the test being run
         """
-
-        data_collector = ObservationalDataCollector(
-            self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
-        )
-        causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
-
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -311,14 +311,14 @@ def _setup_test(
             "control_value": causal_test_case.control_value,
             "adjustment_set": minimal_adjustment_set,
             "outcome": causal_test_case.outcome_variable.name,
-            "df": causal_test_engine.scenario_execution_data_df,
+            "df": data,
             "effect_modifiers": causal_test_case.effect_modifier_configuration,
             "alpha": test["alpha"] if "alpha" in test else 0.05,
         }
         if "formula" in test:
             estimator_kwargs["formula"] = test["formula"]
         estimation_model = test["estimator"](**estimator_kwargs)
-        return causal_test_engine, estimation_model
+        return estimation_model
 
     def _append_to_file(self, line: str, log_level: int = None):
         """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
@@ -370,7 +370,7 @@ def get_args(test_args=None) -> argparse.Namespace:
     parser.add_argument(
         "-w",
         help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
-        "careful",
+        "careful",
         action="store_true",
     )
     parser.add_argument(
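
Note: taken together, these hunks remove the CausalTestEngine indirection. The observational data collector is now created once in setup(), _setup_test builds only the estimator, and the test case executes itself against the causal specification. The snippet below is a rough sketch of the refactored flow, pieced together only from the attributes and call signatures visible in this diff; the enclosing front-end class, its setup(), and the local variables (causal_test_case, test) are assumed from the surrounding context rather than shown here.

    # In setup(): the collector is built once from the scenario and the loaded data
    # (previously a fresh ObservationalDataCollector was built per test in _setup_test).
    self.data_collector = ObservationalDataCollector(self.scenario, self.data)

    # In _execute_test_case(): the estimator is built directly from the collected data,
    # replacing the engine's scenario_execution_data_df.
    estimation_model = self._setup_test(
        causal_test_case=causal_test_case, test=test, data=self.data_collector.data
    )

    # The test case then runs the estimator itself, given the data collector and the
    # causal specification, replacing causal_test_engine.execute_test(...).
    causal_test_result = causal_test_case.execute_test(
        estimator=estimation_model,
        data_collector=self.data_collector,
        causal_specification=self.causal_specification,
    )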