Skip to content

Commit 2d02e5b

Browse files
Check formula is used only for Regression Estimator
1 parent 515d06b commit 2d02e5b

File tree

1 file changed

+25
-22
lines changed

1 file changed

+25
-22
lines changed

causal_testing/json_front/json_class.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from causal_testing.specification.variable import Input, Meta, Output
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25-
from causal_testing.testing.estimators import Estimator
25+
from causal_testing.testing.estimators import Estimator, LinearRegressionEstimator
2626
from causal_testing.testing.base_test_case import BaseTestCase
2727
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2828

@@ -165,12 +165,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
165165
)
166166
failed, result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
167167
msg = (
168-
f"Executing test: {test['name']} \n"
169-
+ f" {causal_test_case} \n"
170-
+ " "
171-
+ ("\n ").join(str(result).split("\n"))
172-
+ "==============\n"
173-
+ f" Result: {'FAILED' if failed else 'Passed'}"
168+
f"Executing test: {test['name']} \n"
169+
+ f" {causal_test_case} \n"
170+
+ " "
171+
+ ("\n ").join(str(result).split("\n"))
172+
+ "==============\n"
173+
+ f" Result: {'FAILED' if failed else 'Passed'}"
174174
)
175175
self._append_to_file(msg, logging.INFO)
176176
return failed, result
@@ -192,11 +192,11 @@ def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict
192192
failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
193193

194194
msg = (
195-
f"Executing concrete test: {test['name']} \n"
196-
+ f"treatment variable: {test['treatment_variable']} \n"
197-
+ f"outcome_variable = {outcome_variable} \n"
198-
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
199-
+ f"Result: {'FAILED' if failed else 'Passed'}"
195+
f"Executing concrete test: {test['name']} \n"
196+
+ f"treatment variable: {test['treatment_variable']} \n"
197+
+ f"outcome_variable = {outcome_variable} \n"
198+
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
199+
+ f"Result: {'FAILED' if failed else 'Passed'}"
200200
)
201201
self._append_to_file(msg, logging.INFO)
202202
return failed, msg
@@ -225,13 +225,13 @@ def _run_metamorphic_tests(self, test: dict, f_flag: bool, effects: dict, mutate
225225
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
226226

227227
msg = (
228-
f"Executing test: {test['name']} \n"
229-
+ " abstract_test \n"
230-
+ f" {abstract_test} \n"
231-
+ f" {abstract_test.treatment_variable.name},"
232-
+ f" {abstract_test.treatment_variable.distribution} \n"
233-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
234-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
228+
f"Executing test: {test['name']} \n"
229+
+ " abstract_test \n"
230+
+ f" {abstract_test} \n"
231+
+ f" {abstract_test.treatment_variable.name},"
232+
+ f" {abstract_test.treatment_variable.distribution} \n"
233+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
234+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
235235
)
236236
self._append_to_file(msg, logging.INFO)
237237
return failures, msg
@@ -257,7 +257,7 @@ def _populate_metas(self):
257257
meta.populate(self.data_collector.data)
258258

259259
def _execute_test_case(
260-
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
260+
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
261261
) -> (bool, CausalTestResult):
262262
"""Executes a singular test case, prints the results and returns the test case result
263263
:param causal_test_case: The concrete test case to be executed
@@ -309,11 +309,14 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
309309
"""
310310
estimator_kwargs = {}
311311
if "formula" in test:
312+
if test["estimator"] != LinearRegressionEstimator:
313+
raise TypeError("Currently only LinearRegressionEstimator supports the use of formulas")
312314
estimator_kwargs["formula"] = test["formula"]
313315
estimator_kwargs["adjustment_set"] = {}
314316
else:
315317
minimal_adjustment_set = self.causal_specification.causal_dag.identification(
316-
causal_test_case.base_test_case)
318+
causal_test_case.base_test_case
319+
)
317320
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
318321
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
319322

@@ -377,7 +380,7 @@ def get_args(test_args=None) -> argparse.Namespace:
377380
parser.add_argument(
378381
"-w",
379382
help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
380-
"careful",
383+
"careful",
381384
action="store_true",
382385
)
383386
parser.add_argument(

0 commit comments

Comments
 (0)