From b6ca8606578403327e7095b35ba32abbc7210fe2 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 9 Jul 2025 09:37:35 +0100 Subject: [PATCH 1/2] Added base test case validation to check that the treatment and outcome aren't the same --- causal_testing/testing/base_test_case.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/causal_testing/testing/base_test_case.py b/causal_testing/testing/base_test_case.py index 2cc02304..d425797f 100644 --- a/causal_testing/testing/base_test_case.py +++ b/causal_testing/testing/base_test_case.py @@ -14,6 +14,9 @@ class BaseTestCase: :param effect: A string representing the effect, current support effects are 'direct' and 'total' """ - treatment_variable: Variable - outcome_variable: Variable - effect: str = Effect.TOTAL.value + def __init__(self, treatment_variable: Variable, outcome_variable: Variable, effect: str = Effect.TOTAL.value): + if treatment_variable == outcome_variable: + raise ValueError(f"Treatment variable {treatment_variable} cannot also be the outcome.") + self.treatment_variable = treatment_variable + self.outcome_variable = outcome_variable + self.effect = effect From 431ac7b53ef24586e40c9d5743a798c788c573a7 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Wed, 9 Jul 2025 09:43:06 +0100 Subject: [PATCH 2/2] Changed it to use post_init --- causal_testing/testing/base_test_case.py | 13 +++++++------ tests/testing_tests/test_causal_test_case.py | 4 ++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/causal_testing/testing/base_test_case.py b/causal_testing/testing/base_test_case.py index d425797f..8838b932 100644 --- a/causal_testing/testing/base_test_case.py +++ b/causal_testing/testing/base_test_case.py @@ -14,9 +14,10 @@ class BaseTestCase: :param effect: A string representing the effect, current support effects are 'direct' and 'total' """ - def __init__(self, treatment_variable: Variable, outcome_variable: Variable, effect: str = Effect.TOTAL.value): - if treatment_variable == outcome_variable: - raise ValueError(f"Treatment variable {treatment_variable} cannot also be the outcome.") - self.treatment_variable = treatment_variable - self.outcome_variable = outcome_variable - self.effect = effect + treatment_variable: Variable + outcome_variable: Variable + effect: str = Effect.TOTAL.value + + def __post_init__(self): + if self.treatment_variable == self.outcome_variable: + raise ValueError(f"Treatment variable {self.treatment_variable} cannot also be the outcome variable.") diff --git a/tests/testing_tests/test_causal_test_case.py b/tests/testing_tests/test_causal_test_case.py index 154d76e3..8e4ef8cc 100644 --- a/tests/testing_tests/test_causal_test_case.py +++ b/tests/testing_tests/test_causal_test_case.py @@ -98,6 +98,10 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.temp_dir_path) + def test_invalid_base_test_case(self): + with self.assertRaises(ValueError): + BaseTestCase(self.A, self.A) + def test_check_minimum_adjustment_set(self): """Check that the minimum adjustment set is correctly made""" minimal_adjustment_set = self.causal_dag.identification(self.base_test_case_A_C)