|
3 | 3 |
|
4 | 4 | from causal_testing.specification.variable import Variable
|
5 | 5 | from causal_testing.testing.causal_test_outcome import CausalTestOutcome
|
| 6 | +from causal_testing.testing.base_test_case import BaseTestCase |
6 | 7 |
|
7 | 8 | logger = logging.getLogger(__name__)
|
8 | 9 |
|
9 | 10 |
|
10 | 11 | class CausalTestCase:
|
11 | 12 | """
|
12 |
| - A causal test case is a triple (X, Delta, Y), where X is an input configuration, Delta is an intervention, and |
13 |
| - Y is the expected causal effect on a particular output. The goal of a causal test case is to test whether the |
14 |
| - intervention Delta made to the input configuration X causes the model-under-test to produce the expected change |
15 |
| - in Y. |
| 13 | + A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome |
| 14 | + variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are |
| 15 | + specified. |
| 16 | +
|
| 17 | + The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the |
| 18 | + model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the |
| 19 | + CausalTestEngine, using either execute_test() function to execute a single test case or packing CausalTestCases into |
| 20 | + a CausalTestSuite and executing them as a batch using the execute_test_suite() function. |
16 | 21 | """
|
17 | 22 |
|
18 | 23 | def __init__(
|
19 | 24 | self,
|
20 |
| - control_input_configuration: dict[Variable:Any], |
| 25 | + base_test_case: BaseTestCase, |
21 | 26 | expected_causal_effect: CausalTestOutcome,
|
22 |
| - outcome_variables: dict[Variable], |
23 |
| - treatment_input_configuration: dict[Variable:Any] = None, |
| 27 | + control_value: Any, |
| 28 | + treatment_value: Any = None, |
24 | 29 | estimate_type: str = "ate",
|
25 | 30 | effect_modifier_configuration: dict[Variable:Any] = None,
|
26 |
| - effect: str = "total", |
27 | 31 | ):
|
28 | 32 | """
|
29 |
| - When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to |
30 |
| - create two distinct input configurations: a control input configuration and a treatment input configuration. |
31 |
| - The former is the input configuration before applying the intervention and the latter is the input configuration |
32 |
| - after applying the intervention. |
33 |
| -
|
34 |
| - :param control_input_configuration: The input configuration representing the control values of the treatment |
35 |
| - variables. |
36 |
| - :param treatment_input_configuration: The input configuration representing the treatment values of the treatment |
37 |
| - variables. That is, the input configuration *after* applying the intervention. |
| 33 | + :param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect |
| 34 | + :param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect). |
| 35 | + :param control_value: The control value for the treatment variable (before intervention). |
| 36 | + :param treatment_value: The treatment value for the treatment variable (after intervention). |
| 37 | + :param estimate_type: A string which denotes the type of estimate to return |
| 38 | + :param effect_modifier_configuration: |
38 | 39 | """
|
39 |
| - self.control_input_configuration = control_input_configuration |
| 40 | + self.base_test_case = base_test_case |
| 41 | + self.control_value = control_value |
40 | 42 | self.expected_causal_effect = expected_causal_effect
|
41 |
| - self.outcome_variables = outcome_variables |
42 |
| - self.treatment_input_configuration = treatment_input_configuration |
| 43 | + self.outcome_variable = base_test_case.outcome_variable |
| 44 | + self.treatment_variable = base_test_case.treatment_variable |
| 45 | + self.treatment_value = treatment_value |
43 | 46 | self.estimate_type = estimate_type
|
44 |
| - self.effect = effect |
| 47 | + self.effect = base_test_case.effect |
| 48 | + |
45 | 49 | if effect_modifier_configuration:
|
46 | 50 | self.effect_modifier_configuration = effect_modifier_configuration
|
47 | 51 | else:
|
48 | 52 | self.effect_modifier_configuration = dict()
|
49 |
| - assert ( |
50 |
| - self.control_input_configuration.keys() == self.treatment_input_configuration.keys() |
51 |
| - ), "Control and treatment input configurations must have the same keys." |
52 | 53 |
|
53 |
| - def get_treatment_variables(self): |
54 |
| - """Return a list of the treatment variables (as strings) for this causal test case.""" |
55 |
| - return [v.name for v in self.control_input_configuration] |
| 54 | + def get_treatment_variable(self): |
| 55 | + """Return the treatment variable name (as string) for this causal test case""" |
| 56 | + return self.treatment_variable.name |
56 | 57 |
|
57 |
| - def get_outcome_variables(self): |
58 |
| - """Return a list of the outcome variables (as strings) for this causal test case.""" |
59 |
| - return [v.name for v in self.outcome_variables] |
| 58 | + def get_outcome_variable(self): |
| 59 | + """Return the outcome variable name (as string) for this causal test case.""" |
| 60 | + return self.outcome_variable.name |
60 | 61 |
|
61 |
| - def get_control_values(self): |
62 |
| - """Return a list of the control values for each treatment variable in this causal test case.""" |
63 |
| - return list(self.control_input_configuration.values()) |
| 62 | + def get_control_value(self): |
| 63 | + """Return a the control value of the treatment variable in this causal test case.""" |
| 64 | + return self.control_value |
64 | 65 |
|
65 |
| - def get_treatment_values(self): |
66 |
| - """Return a list of the treatment values for each treatment variable in this causal test case.""" |
67 |
| - return list(self.treatment_input_configuration.values()) |
| 66 | + def get_treatment_value(self): |
| 67 | + """Return the treatment value of the treatment variable in this causal test case.""" |
| 68 | + return self.treatment_value |
68 | 69 |
|
69 | 70 | def __str__(self):
|
70 |
| - treatment_config = {k.name: v for k, v in self.treatment_input_configuration.items()} |
71 |
| - control_config = {k.name: v for k, v in self.control_input_configuration.items()} |
| 71 | + treatment_config = {self.treatment_variable.name: self.treatment_value} |
| 72 | + control_config = {self.treatment_variable.name: self.control_value} |
| 73 | + outcome_variable = {self.outcome_variable} |
72 | 74 | return (
|
73 | 75 | f"Running {treatment_config} instead of {control_config} should cause the following "
|
74 |
| - f"changes to {self.outcome_variables}: {self.expected_causal_effect}." |
| 76 | + f"changes to {outcome_variable}: {self.expected_causal_effect}." |
75 | 77 | )
|
0 commit comments