Skip to content

Commit 1098043

Browse files
Merge pull request #103 from CITCOM-project/base-causal-test-case
This merge contains both the base_causal_test_case branch and the causal_test_case_refactor branch
2 parents b2ce1eb + c8c1d3f commit 1098043

18 files changed

+675
-231
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from causal_testing.specification.variable import Variable
1111
from causal_testing.testing.causal_test_case import CausalTestCase
1212
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
13+
from causal_testing.testing.base_test_case import BaseTestCase
1314

1415
from enum import Enum
1516

@@ -18,7 +19,7 @@
1819

1920
class AbstractCausalTestCase:
2021
"""
21-
An abstract test case serves as a generator for concrete test cases. Instead of having concrete conctrol
22+
An abstract test case serves as a generator for concrete test cases. Instead of having concrete control
2223
and treatment values, we instead just specify the intervention and the treatment variables. This then
2324
enables potentially infinite concrete test cases to be generated between different values of the treatment.
2425
"""
@@ -33,10 +34,11 @@ def __init__(
3334
estimate_type: str = "ate",
3435
effect: str = "total",
3536
):
36-
assert treatment_variable in scenario.variables.values(), (
37-
"Treatment variables must be a subset of variables."
38-
+ f" Instead got:\ntreatment_variable={treatment_variable}\nvariables={scenario.variables}"
39-
)
37+
if treatment_variable not in scenario.variables.values():
38+
raise ValueError(
39+
"Treatment variables must be a subset of variables."
40+
+ f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
41+
)
4042

4143
assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
4244

@@ -119,16 +121,21 @@ def _generate_concrete_tests(
119121
)
120122
model = optimizer.model()
121123

124+
base_test_case = BaseTestCase(
125+
treatment_variable=self.treatment_variable,
126+
outcome_variable=list(self.expected_causal_effect.keys())[0],
127+
effect=self.effect,
128+
)
129+
122130
concrete_test = CausalTestCase(
123-
control_input_configuration={v: v.cast(model[v.z3]) for v in [self.treatment_variable]},
124-
treatment_input_configuration={
125-
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in [self.treatment_variable]
126-
},
131+
base_test_case=base_test_case,
132+
control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
133+
treatment_value=self.treatment_variable.cast(
134+
model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
135+
),
127136
expected_causal_effect=list(self.expected_causal_effect.values())[0],
128-
outcome_variables=list(self.expected_causal_effect.keys()),
129137
estimate_type=self.estimate_type,
130138
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
131-
effect=self.effect,
132139
)
133140

134141
for v in self.scenario.inputs():
@@ -150,7 +157,7 @@ def _generate_concrete_tests(
150157
# Treatment run
151158
if rct:
152159
treatment_run = control_run.copy()
153-
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
160+
treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
154161
treatment_run["bin"] = index
155162
runs.append(treatment_run)
156163

@@ -197,7 +204,7 @@ def generate_concrete_tests(
197204
runs = pd.concat([runs, runs_])
198205
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
199206

200-
control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
207+
control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
201208
ks_stats = {
202209
var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
203210
for var in control_configs.columns
@@ -220,8 +227,8 @@ def generate_concrete_tests(
220227
for var in effect_modifier_configs.columns
221228
}
222229
)
223-
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
224-
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
230+
control_values = [test.control_value for test in concrete_tests]
231+
treatment_values = [test.treatment_value for test in concrete_tests]
225232

226233
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
227234
set(zip(control_values, treatment_values))

causal_testing/json_front/json_class.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
123123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
124124
logger.info("Executing test: %s", test["name"])
125125
logger.info(abstract_test)
126-
logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
126+
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
127127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
128128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
129129
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
@@ -201,15 +201,15 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
201201
"""
202202
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
203203
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
204-
causal_test_engine.identification(causal_test_case)
205-
treatment_vars = list(causal_test_case.treatment_input_configuration)
206-
minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - {v.name for v in treatment_vars}
204+
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
205+
treatment_var = causal_test_case.treatment_variable
206+
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
207207
estimation_model = estimator(
208-
(list(treatment_vars)[0].name,),
209-
[causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
210-
[causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
208+
(treatment_var.name,),
209+
causal_test_case.treatment_value,
210+
causal_test_case.control_value,
211211
minimal_adjustment_set,
212-
(list(causal_test_case.outcome_variables)[0].name,),
212+
(causal_test_case.outcome_variable.name,),
213213
causal_test_engine.scenario_execution_data_df,
214214
effect_modifiers=causal_test_case.effect_modifier_configuration,
215215
)

causal_testing/specification/causal_dag.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,5 +468,28 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
468468
return True
469469
return any([self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)])
470470

471+
def identification(self, base_test_case):
472+
"""Identify and return the minimum adjustment set
473+
474+
:param base_test_case: A base test case instance containing the outcome_variable and the
475+
treatment_variable required for identification.
476+
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
477+
estimate as opposed to a purely associational estimate.
478+
"""
479+
minimal_adjustment_sets = []
480+
if base_test_case.effect == "total":
481+
minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets(
482+
[base_test_case.treatment_variable.name], [base_test_case.outcome_variable.name]
483+
)
484+
elif base_test_case.effect == "direct":
485+
minimal_adjustment_sets = self.direct_effect_adjustment_sets(
486+
[base_test_case.treatment_variable.name], [base_test_case.outcome_variable.name]
487+
)
488+
else:
489+
raise ValueError("Causal effect should be 'total' or 'direct'")
490+
491+
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
492+
return minimal_adjustment_set
493+
471494
def __str__(self):
472495
return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"

causal_testing/testing/base_test_case.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from dataclasses import dataclass
2+
from causal_testing.specification.variable import Variable
3+
from causal_testing.testing.effect import Effect
4+
5+
6+
@dataclass(frozen=True)
7+
class BaseTestCase:
8+
"""
9+
A base causal test case represents the relationship of an edge on a causal DAG.
10+
"""
11+
12+
treatment_variable: Variable
13+
outcome_variable: Variable
14+
effect: str = Effect.total.value

causal_testing/testing/causal_test_case.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,73 +3,75 @@
33

44
from causal_testing.specification.variable import Variable
55
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
6+
from causal_testing.testing.base_test_case import BaseTestCase
67

78
logger = logging.getLogger(__name__)
89

910

1011
class CausalTestCase:
1112
"""
12-
A causal test case is a triple (X, Delta, Y), where X is an input configuration, Delta is an intervention, and
13-
Y is the expected causal effect on a particular output. The goal of a causal test case is to test whether the
14-
intervention Delta made to the input configuration X causes the model-under-test to produce the expected change
15-
in Y.
13+
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
14+
variables, a CausalTestCase stores the control and treatment values of the treatment variable, together with the
15+
expected causal effect on the outcome variable.
16+
17+
The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
18+
model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the
19+
CausalTestEngine, either using the execute_test() function to execute a single test case or packing CausalTestCases into
20+
a CausalTestSuite and executing them as a batch using the execute_test_suite() function.
1621
"""
1722

1823
def __init__(
1924
self,
20-
control_input_configuration: dict[Variable:Any],
25+
base_test_case: BaseTestCase,
2126
expected_causal_effect: CausalTestOutcome,
22-
outcome_variables: dict[Variable],
23-
treatment_input_configuration: dict[Variable:Any] = None,
27+
control_value: Any,
28+
treatment_value: Any = None,
2429
estimate_type: str = "ate",
2530
effect_modifier_configuration: dict[Variable:Any] = None,
26-
effect: str = "total",
2731
):
2832
"""
29-
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
30-
create two distinct input configurations: a control input configuration and a treatment input configuration.
31-
The former is the input configuration before applying the intervention and the latter is the input configuration
32-
after applying the intervention.
33-
34-
:param control_input_configuration: The input configuration representing the control values of the treatment
35-
variables.
36-
:param treatment_input_configuration: The input configuration representing the treatment values of the treatment
37-
variables. That is, the input configuration *after* applying the intervention.
33+
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
34+
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
35+
:param control_value: The control value for the treatment variable (before intervention).
36+
:param treatment_value: The treatment value for the treatment variable (after intervention).
37+
:param estimate_type: A string which denotes the type of estimate to return
38+
:param effect_modifier_configuration: A dictionary mapping effect modifier variables to their values.
3839
"""
39-
self.control_input_configuration = control_input_configuration
40+
self.base_test_case = base_test_case
41+
self.control_value = control_value
4042
self.expected_causal_effect = expected_causal_effect
41-
self.outcome_variables = outcome_variables
42-
self.treatment_input_configuration = treatment_input_configuration
43+
self.outcome_variable = base_test_case.outcome_variable
44+
self.treatment_variable = base_test_case.treatment_variable
45+
self.treatment_value = treatment_value
4346
self.estimate_type = estimate_type
44-
self.effect = effect
47+
self.effect = base_test_case.effect
48+
4549
if effect_modifier_configuration:
4650
self.effect_modifier_configuration = effect_modifier_configuration
4751
else:
4852
self.effect_modifier_configuration = dict()
49-
assert (
50-
self.control_input_configuration.keys() == self.treatment_input_configuration.keys()
51-
), "Control and treatment input configurations must have the same keys."
5253

53-
def get_treatment_variables(self):
54-
"""Return a list of the treatment variables (as strings) for this causal test case."""
55-
return [v.name for v in self.control_input_configuration]
54+
def get_treatment_variable(self):
55+
"""Return the treatment variable name (as string) for this causal test case"""
56+
return self.treatment_variable.name
5657

57-
def get_outcome_variables(self):
58-
"""Return a list of the outcome variables (as strings) for this causal test case."""
59-
return [v.name for v in self.outcome_variables]
58+
def get_outcome_variable(self):
59+
"""Return the outcome variable name (as string) for this causal test case."""
60+
return self.outcome_variable.name
6061

61-
def get_control_values(self):
62-
"""Return a list of the control values for each treatment variable in this causal test case."""
63-
return list(self.control_input_configuration.values())
62+
def get_control_value(self):
63+
"""Return the control value of the treatment variable in this causal test case."""
64+
return self.control_value
6465

65-
def get_treatment_values(self):
66-
"""Return a list of the treatment values for each treatment variable in this causal test case."""
67-
return list(self.treatment_input_configuration.values())
66+
def get_treatment_value(self):
67+
"""Return the treatment value of the treatment variable in this causal test case."""
68+
return self.treatment_value
6869

6970
def __str__(self):
70-
treatment_config = {k.name: v for k, v in self.treatment_input_configuration.items()}
71-
control_config = {k.name: v for k, v in self.control_input_configuration.items()}
71+
treatment_config = {self.treatment_variable.name: self.treatment_value}
72+
control_config = {self.treatment_variable.name: self.control_value}
73+
outcome_variable = {self.outcome_variable}
7274
return (
7375
f"Running {treatment_config} instead of {control_config} should cause the following "
74-
f"changes to {self.outcome_variables}: {self.expected_causal_effect}."
76+
f"changes to {outcome_variable}: {self.expected_causal_effect}."
7577
)

0 commit comments

Comments
 (0)