Skip to content

Commit 10e4cb2

Browse files
Unittest to check estimator type when formula is used
1 parent ffb5712 commit 10e4cb2

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
309309
"""
310310
estimator_kwargs = {}
311311
if "formula" in test:
312-
if test["estimator"] != LinearRegressionEstimator or LogisticRegressionEstimator:
312+
if test["estimator"] != (LinearRegressionEstimator or LogisticRegressionEstimator):
313313
raise TypeError(
314314
"Currently only LinearRegressionEstimator and LogisticRegressionEstimator supports the use of formulas"
315315
)

tests/json_front_tests/test_json_class.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import scipy
55
import os
66

7-
from causal_testing.testing.estimators import LinearRegressionEstimator
7+
from causal_testing.testing.estimators import LinearRegressionEstimator, CausalForestEstimator
88
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
99
from tests.test_helpers import remove_temp_dir_if_existent
1010
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
@@ -291,6 +291,31 @@ def test_no_data_provided(self):
291291
with self.assertRaises(ValueError):
292292
json_class.setup(self.scenario)
293293

294+
def test_estimator_formula_type_check(self):
295+
example_test = {
296+
"tests": [
297+
{
298+
"name": "test1",
299+
"mutations": {"test_input": "Increase"},
300+
"estimator": "CausalForestEstimator",
301+
"estimate_type": "ate",
302+
"effect_modifiers": [],
303+
"expected_effect": {"test_output": "Positive"},
304+
"skip": False,
305+
"formula": "test_output ~ test_input",
306+
}
307+
]
308+
}
309+
self.json_class.test_plan = example_test
310+
effects = {"Positive": Positive()}
311+
mutates = {
312+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
313+
> self.json_class.scenario.variables[x].z3
314+
}
315+
estimators = {"CausalForestEstimator": CausalForestEstimator}
316+
with self.assertRaises(TypeError):
317+
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
318+
294319
def tearDown(self) -> None:
295320
remove_temp_dir_if_existent()
296321
if os.path.exists("temp_out.txt"):

0 commit comments

Comments
 (0)