Skip to content

Commit f8346ac

Browse files
refactor of causal_test_case.py
1 parent ac2c2a3 commit f8346ac

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
from causal_testing.specification.variable import Variable
55
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
6-
6+
from causal_testing.testing.base_causal_test import BaseCausalTest
77
logger = logging.getLogger(__name__)
88

99

10-
class CausalTestCase:
11-
"""
10+
class CausalTestCase(BaseCausalTest):
11+
""")
1212
A causal test case is a triple (X, Delta, Y), where X is an input configuration, Delta is an intervention, and
1313
Y is the expected causal effect on a particular output. The goal of a causal test case is to test whether the
1414
intervention Delta made to the input configuration X causes the model-under-test to produce the expected change
@@ -17,13 +17,13 @@ class CausalTestCase:
1717

1818
def __init__(
1919
self,
20-
control_input_configuration: dict[Variable:Any],
20+
base_causal_test: BaseCausalTest,
2121
expected_causal_effect: CausalTestOutcome,
22-
outcome_variables: dict[Variable],
23-
treatment_input_configuration: dict[Variable:Any] = None,
22+
control_value: Any,
23+
outcome_value: Any,
24+
treatment_value: Any = None,
2425
estimate_type: str = "ate",
2526
effect_modifier_configuration: dict[Variable:Any] = None,
26-
effect: str = "total",
2727
):
2828
"""
2929
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
@@ -36,12 +36,12 @@ def __init__(
3636
:param treatment_input_configuration: The input configuration representing the treatment values of the treatment
3737
variables. That is, the input configuration *after* applying the intervention.
3838
"""
39-
self.control_input_configuration = control_input_configuration
39+
self.control_input_configuration = {base_causal_test.treatment_variable: control_value}
4040
self.expected_causal_effect = expected_causal_effect
41-
self.outcome_variables = outcome_variables
42-
self.treatment_input_configuration = treatment_input_configuration
41+
self.outcome_variable = base_causal_test.outcome_variable
42+
self.treatment_input_configuration = {base_causal_test.treatment_variable: treatment_value}
4343
self.estimate_type = estimate_type
44-
self.effect = effect
44+
self.effect = base_causal_test.effect
4545
if effect_modifier_configuration:
4646
self.effect_modifier_configuration = effect_modifier_configuration
4747
else:

0 commit comments

Comments
 (0)