Skip to content

Commit 431ac7b

Browse files
committed
Changed it to use post_init
1 parent b6ca860 commit 431ac7b

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

causal_testing/testing/base_test_case.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ class BaseTestCase:
1414
:param effect: A string representing the effect, current support effects are 'direct' and 'total'
1515
"""
1616

17-
def __init__(self, treatment_variable: Variable, outcome_variable: Variable, effect: str = Effect.TOTAL.value):
18-
if treatment_variable == outcome_variable:
19-
raise ValueError(f"Treatment variable {treatment_variable} cannot also be the outcome.")
20-
self.treatment_variable = treatment_variable
21-
self.outcome_variable = outcome_variable
22-
self.effect = effect
17+
treatment_variable: Variable
18+
outcome_variable: Variable
19+
effect: str = Effect.TOTAL.value
20+
21+
def __post_init__(self):
22+
if self.treatment_variable == self.outcome_variable:
23+
raise ValueError(f"Treatment variable {self.treatment_variable} cannot also be the outcome variable.")

tests/testing_tests/test_causal_test_case.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def setUp(self) -> None:
9898
def tearDown(self) -> None:
9999
shutil.rmtree(self.temp_dir_path)
100100

101+
def test_invalid_base_test_case(self):
102+
with self.assertRaises(ValueError):
103+
BaseTestCase(self.A, self.A)
104+
101105
def test_check_minimum_adjustment_set(self):
102106
"""Check that the minimum adjustment set is correctly made"""
103107
minimal_adjustment_set = self.causal_dag.identification(self.base_test_case_A_C)

0 commit comments

Comments
 (0)