22
22
from causal_testing .specification .variable import Input , Meta , Output
23
23
from causal_testing .testing .causal_test_case import CausalTestCase
24
24
from causal_testing .testing .causal_test_result import CausalTestResult
25
- from causal_testing .testing .estimators import Estimator
25
+ from causal_testing .testing .estimators import Estimator , LinearRegressionEstimator
26
26
from causal_testing .testing .base_test_case import BaseTestCase
27
27
from causal_testing .testing .causal_test_adequacy import DataAdequacy
28
28
@@ -165,12 +165,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
165
165
)
166
166
failed , result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
167
167
msg = (
168
- f"Executing test: { test ['name' ]} \n "
169
- + f" { causal_test_case } \n "
170
- + " "
171
- + ("\n " ).join (str (result ).split ("\n " ))
172
- + "==============\n "
173
- + f" Result: { 'FAILED' if failed else 'Passed' } "
168
+ f"Executing test: { test ['name' ]} \n "
169
+ + f" { causal_test_case } \n "
170
+ + " "
171
+ + ("\n " ).join (str (result ).split ("\n " ))
172
+ + "==============\n "
173
+ + f" Result: { 'FAILED' if failed else 'Passed' } "
174
174
)
175
175
self ._append_to_file (msg , logging .INFO )
176
176
return failed , result
@@ -192,11 +192,11 @@ def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict
192
192
failed , msg = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
193
193
194
194
msg = (
195
- f"Executing concrete test: { test ['name' ]} \n "
196
- + f"treatment variable: { test ['treatment_variable' ]} \n "
197
- + f"outcome_variable = { outcome_variable } \n "
198
- + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
199
- + f"Result: { 'FAILED' if failed else 'Passed' } "
195
+ f"Executing concrete test: { test ['name' ]} \n "
196
+ + f"treatment variable: { test ['treatment_variable' ]} \n "
197
+ + f"outcome_variable = { outcome_variable } \n "
198
+ + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
199
+ + f"Result: { 'FAILED' if failed else 'Passed' } "
200
200
)
201
201
self ._append_to_file (msg , logging .INFO )
202
202
return failed , msg
@@ -225,13 +225,13 @@ def _run_metamorphic_tests(self, test: dict, f_flag: bool, effects: dict, mutate
225
225
failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
226
226
227
227
msg = (
228
- f"Executing test: { test ['name' ]} \n "
229
- + " abstract_test \n "
230
- + f" { abstract_test } \n "
231
- + f" { abstract_test .treatment_variable .name } ,"
232
- + f" { abstract_test .treatment_variable .distribution } \n "
233
- + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
234
- + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
228
+ f"Executing test: { test ['name' ]} \n "
229
+ + " abstract_test \n "
230
+ + f" { abstract_test } \n "
231
+ + f" { abstract_test .treatment_variable .name } ,"
232
+ + f" { abstract_test .treatment_variable .distribution } \n "
233
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
234
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
235
235
)
236
236
self ._append_to_file (msg , logging .INFO )
237
237
return failures , msg
@@ -257,7 +257,7 @@ def _populate_metas(self):
257
257
meta .populate (self .data_collector .data )
258
258
259
259
def _execute_test_case (
260
- self , causal_test_case : CausalTestCase , test : Mapping , f_flag : bool
260
+ self , causal_test_case : CausalTestCase , test : Mapping , f_flag : bool
261
261
) -> (bool , CausalTestResult ):
262
262
"""Executes a singular test case, prints the results and returns the test case result
263
263
:param causal_test_case: The concrete test case to be executed
@@ -309,11 +309,14 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
309
309
"""
310
310
estimator_kwargs = {}
311
311
if "formula" in test :
312
+ if test ["estimator" ] != LinearRegressionEstimator :
313
+ raise TypeError ("Currently only LinearRegressionEstimator supports the use of formulas" )
312
314
estimator_kwargs ["formula" ] = test ["formula" ]
313
315
estimator_kwargs ["adjustment_set" ] = {}
314
316
else :
315
317
minimal_adjustment_set = self .causal_specification .causal_dag .identification (
316
- causal_test_case .base_test_case )
318
+ causal_test_case .base_test_case
319
+ )
317
320
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case .treatment_variable }
318
321
estimator_kwargs ["adjustment_set" ] = minimal_adjustment_set
319
322
@@ -377,7 +380,7 @@ def get_args(test_args=None) -> argparse.Namespace:
377
380
parser .add_argument (
378
381
"-w" ,
379
382
help = "Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
380
- "careful" ,
383
+ "careful" ,
381
384
action = "store_true" ,
382
385
)
383
386
parser .add_argument (
0 commit comments