|
4 | 4 | import scipy
|
5 | 5 | import os
|
6 | 6 |
|
7 |
| -from causal_testing.testing.estimators import LinearRegressionEstimator |
| 7 | +from causal_testing.testing.estimators import LinearRegressionEstimator, CausalForestEstimator |
8 | 8 | from causal_testing.testing.causal_test_outcome import NoEffect, Positive
|
9 | 9 | from tests.test_helpers import remove_temp_dir_if_existent
|
10 | 10 | from causal_testing.json_front.json_class import JsonUtility, CausalVariables
|
@@ -291,6 +291,31 @@ def test_no_data_provided(self):
|
291 | 291 | with self.assertRaises(ValueError):
|
292 | 292 | json_class.setup(self.scenario)
|
293 | 293 |
|
| 294 | + def test_estimator_formula_type_check(self): |
| 295 | + example_test = { |
| 296 | + "tests": [ |
| 297 | + { |
| 298 | + "name": "test1", |
| 299 | + "mutations": {"test_input": "Increase"}, |
| 300 | + "estimator": "CausalForestEstimator", |
| 301 | + "estimate_type": "ate", |
| 302 | + "effect_modifiers": [], |
| 303 | + "expected_effect": {"test_output": "Positive"}, |
| 304 | + "skip": False, |
| 305 | + "formula": "test_output ~ test_input", |
| 306 | + } |
| 307 | + ] |
| 308 | + } |
| 309 | + self.json_class.test_plan = example_test |
| 310 | + effects = {"Positive": Positive()} |
| 311 | + mutates = { |
| 312 | + "Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3 |
| 313 | + > self.json_class.scenario.variables[x].z3 |
| 314 | + } |
| 315 | + estimators = {"CausalForestEstimator": CausalForestEstimator} |
| 316 | + with self.assertRaises(TypeError): |
| 317 | + self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False) |
| 318 | + |
294 | 319 | def tearDown(self) -> None:
|
295 | 320 | remove_temp_dir_if_existent()
|
296 | 321 | if os.path.exists("temp_out.txt"):
|
|
0 commit comments