diff --git a/causal_testing/testing/base_test_case.py b/causal_testing/testing/base_test_case.py index 2cc02304..8838b932 100644 --- a/causal_testing/testing/base_test_case.py +++ b/causal_testing/testing/base_test_case.py @@ -17,3 +17,7 @@ class BaseTestCase: treatment_variable: Variable outcome_variable: Variable effect: str = Effect.TOTAL.value + + def __post_init__(self): + if self.treatment_variable == self.outcome_variable: + raise ValueError(f"Treatment variable {self.treatment_variable} cannot also be the outcome variable.") diff --git a/tests/testing_tests/test_causal_test_case.py b/tests/testing_tests/test_causal_test_case.py index 154d76e3..8e4ef8cc 100644 --- a/tests/testing_tests/test_causal_test_case.py +++ b/tests/testing_tests/test_causal_test_case.py @@ -98,6 +98,10 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) + def test_invalid_base_test_case(self): + with self.assertRaises(ValueError): + BaseTestCase(self.A, self.A) + def test_check_minimum_adjustment_set(self): """Check that the minimum adjustment set is correctly made""" minimal_adjustment_set = self.causal_dag.identification(self.base_test_case_A_C)