Skip to content

Commit ed9c402

Browse files
Update causal_test_case.py and related tests
1 parent c6b676f commit ed9c402

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_treatment_variables(self):
5757

5858
def get_outcome_variables(self):
5959
"""Return a list of the outcome variables (as strings) for this causal test case."""
60-
return [v.name for v in self.outcome_variables]
60+
return [self.outcome_variable.name]
6161

6262
def get_control_values(self):
6363
"""Return a list of the control values for each treatment variable in this causal test case."""
@@ -70,7 +70,8 @@ def get_treatment_values(self):
7070
def __str__(self):
7171
treatment_config = {k.name: v for k, v in self.treatment_input_configuration.items()}
7272
control_config = {k.name: v for k, v in self.control_input_configuration.items()}
73+
outcome_variable = {self.outcome_variable}
7374
return (
7475
f"Running {treatment_config} instead of {control_config} should cause the following "
75-
f"changes to {self.outcome_variable}: {self.expected_causal_effect}."
76+
f"changes to {outcome_variable}: {self.expected_causal_effect}."
7677
)

tests/testing_tests/test_causal_test_case.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from causal_testing.specification.causal_dag import CausalDAG
77
from causal_testing.testing.causal_test_case import CausalTestCase
88
from causal_testing.testing.causal_test_outcome import ExactValue
9+
from causal_testing.testing.base_causal_test import BaseCausalTest
910

1011
class TestCausalTestEngineObservational(unittest.TestCase):
1112
""" Test the CausalTestEngine workflow using observational data.
@@ -33,11 +34,13 @@ def setUp(self) -> None:
3334

3435
# 3. Create an intervention and causal test case
3536
self.expected_causal_effect = ExactValue(4)
37+
self.base_causal_test = BaseCausalTest(A, C)
3638
self.causal_test_case = CausalTestCase(
37-
control_input_configuration={A: 0},
39+
base_causal_test=self.base_causal_test,
3840
expected_causal_effect=self.expected_causal_effect,
39-
treatment_input_configuration={A: 1},
40-
outcome_variables={C})
41+
control_value=0,
42+
treatment_value=1
43+
)
4144

4245
def test_get_treatment_variables(self):
4346
self.assertEqual(self.causal_test_case.get_treatment_variables(), ["A"])

0 commit comments

Comments
 (0)