Skip to content

Commit f99c46a

Browse files
Merge pull request #118 from CITCOM-project/causal_test_case_refactor
Causal test case refactor
2 parents 92b6250 + 61e7def commit f99c46a

File tree

13 files changed

+161
-150
lines changed

13 files changed

+161
-150
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class AbstractCausalTestCase:
2121
"""
22-
An abstract test case serves as a generator for concrete test cases. Instead of having concrete conctrol
22+
An abstract test case serves as a generator for concrete test cases. Instead of having concrete control
2323
and treatment values, we instead just specify the intervention and the treatment variables. This then
2424
enables potentially infinite concrete test cases to be generated between different values of the treatment.
2525
"""
@@ -157,7 +157,7 @@ def _generate_concrete_tests(
157157
# Treatment run
158158
if rct:
159159
treatment_run = control_run.copy()
160-
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
160+
treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
161161
treatment_run["bin"] = index
162162
runs.append(treatment_run)
163163

@@ -204,7 +204,7 @@ def generate_concrete_tests(
204204
runs = pd.concat([runs, runs_])
205205
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
206206

207-
control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
207+
control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
208208
ks_stats = {
209209
var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
210210
for var in control_configs.columns
@@ -227,8 +227,8 @@ def generate_concrete_tests(
227227
for var in effect_modifier_configs.columns
228228
}
229229
)
230-
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
231-
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
230+
control_values = [test.control_value for test in concrete_tests]
231+
treatment_values = [test.treatment_value for test in concrete_tests]
232232

233233
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
234234
set(zip(control_values, treatment_values))

causal_testing/json_front/json_class.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,12 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
202202
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
203203
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
204204
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
205-
treatment_vars = list(causal_test_case.treatment_input_configuration)
206-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}
205+
treatment_var = causal_test_case.treatment_variable
206+
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
207207
estimation_model = estimator(
208-
(list(treatment_vars)[0].name,),
209-
[causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
210-
[causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
208+
(treatment_var.name,),
209+
causal_test_case.treatment_value,
210+
causal_test_case.control_value,
211211
minimal_adjustment_set,
212212
(causal_test_case.outcome_variable.name,),
213213
causal_test_engine.scenario_execution_data_df,

causal_testing/specification/causal_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
471471
def identification(self, base_test_case):
472472
"""Identify and return the minimum adjustment set
473473
474-
:param base_test_case: A base test case class instance containing the outcome_variable and the
474+
:param base_test_case: A base test case instance containing the outcome_variable and the
475475
treatment_variable required for identification.
476476
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
477477
estimate as opposed to a purely associational estimate.

causal_testing/testing/causal_test_case.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99

1010

1111
class CausalTestCase:
12-
""")
13-
A causal test case is a triple (X, Delta, Y), where X is an input configuration, Delta is an intervention, and
14-
Y is the expected causal effect on a particular output. The goal of a causal test case is to test whether the
15-
intervention Delta made to the input configuration X causes the model-under-test to produce the expected change
16-
in Y.
12+
"""
13+
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
14+
variables, a CausalTestCase stores the control and treatment values of the treatment variable, together with the
15+
expected causal effect on the outcome variable.
16+
17+
The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
18+
model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the
19+
CausalTestEngine, either using the execute_test() function to execute a single test case or packing CausalTestCases into
20+
a CausalTestSuite and executing them as a batch using the execute_test_suite() function.
1721
"""
1822

1923
def __init__(
@@ -26,50 +30,46 @@ def __init__(
2630
effect_modifier_configuration: dict[Variable:Any] = None,
2731
):
2832
"""
29-
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
30-
create two distinct input configurations: a control input configuration and a treatment input configuration.
31-
The former is the input configuration before applying the intervention and the latter is the input configuration
32-
after applying the intervention.
33-
34-
:param control_input_configuration: The input configuration representing the control values of the treatment
35-
variables.
36-
:param treatment_input_configuration: The input configuration representing the treatment values of the treatment
37-
variables. That is, the input configuration *after* applying the intervention.
33+
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
34+
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
35+
:param control_value: The control value for the treatment variable (before intervention).
36+
:param treatment_value: The treatment value for the treatment variable (after intervention).
37+
:param estimate_type: A string which denotes the type of estimate to return
38+
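:param effect_modifier_configuration: An optional dictionary mapping effect modifier variables to their values; an empty dictionary is used when none is given.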
3839
"""
3940
self.base_test_case = base_test_case
40-
self.control_input_configuration = {base_test_case.treatment_variable: control_value}
41+
self.control_value = control_value
4142
self.expected_causal_effect = expected_causal_effect
4243
self.outcome_variable = base_test_case.outcome_variable
43-
self.treatment_input_configuration = {base_test_case.treatment_variable: treatment_value}
44+
self.treatment_variable = base_test_case.treatment_variable
45+
self.treatment_value = treatment_value
4446
self.estimate_type = estimate_type
4547
self.effect = base_test_case.effect
48+
4649
if effect_modifier_configuration:
4750
self.effect_modifier_configuration = effect_modifier_configuration
4851
else:
4952
self.effect_modifier_configuration = dict()
50-
assert (
51-
self.control_input_configuration.keys() == self.treatment_input_configuration.keys()
52-
), "Control and treatment input configurations must have the same keys."
5353

54-
def get_treatment_variables(self):
55-
"""Return a list of the treatment variables (as strings) for this causal test case."""
56-
return [v.name for v in self.control_input_configuration]
54+
def get_treatment_variable(self):
55+
"""Return the treatment variable name (as string) for this causal test case"""
56+
return self.treatment_variable.name
5757

58-
def get_outcome_variables(self):
59-
"""Return a list of the outcome variables (as strings) for this causal test case."""
60-
return [self.outcome_variable.name]
58+
def get_outcome_variable(self):
59+
"""Return the outcome variable name (as string) for this causal test case."""
60+
return self.outcome_variable.name
6161

62-
def get_control_values(self):
63-
"""Return a list of the control values for each treatment variable in this causal test case."""
64-
return list(self.control_input_configuration.values())
62+
def get_control_value(self):
63+
"""Return the control value of the treatment variable in this causal test case."""
64+
return self.control_value
6565

66-
def get_treatment_values(self):
67-
"""Return a list of the treatment values for each treatment variable in this causal test case."""
68-
return list(self.treatment_input_configuration.values())
66+
def get_treatment_value(self):
67+
"""Return the treatment value of the treatment variable in this causal test case."""
68+
return self.treatment_value
6969

7070
def __str__(self):
71-
treatment_config = {k.name: v for k, v in self.treatment_input_configuration.items()}
72-
control_config = {k.name: v for k, v in self.control_input_configuration.items()}
71+
treatment_config = {self.treatment_variable.name: self.treatment_value}
72+
control_config = {self.treatment_variable.name: self.control_value}
7373
outcome_variable = {self.outcome_variable}
7474
return (
7575
f"Running {treatment_config} instead of {control_config} should cause the following "

causal_testing/testing/causal_test_engine.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8888
causal_test_results = []
8989

9090
for test in tests:
91-
treatment_variable = list(test.treatment_input_configuration.keys())[0]
92-
treatment_value = list(test.treatment_input_configuration.values())[0]
93-
control_value = list(test.control_input_configuration.values())[0]
91+
treatment_variable = test.treatment_variable
92+
treatment_value = test.treatment_value
93+
control_value = test.control_value
9494
estimator = EstimatorClass(
9595
(treatment_variable.name,),
9696
treatment_value,
@@ -131,19 +131,16 @@ def execute_test(
131131
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
132132
if estimator.df is None:
133133
estimator.df = self.scenario_execution_data_df
134-
treatment_variable = list(causal_test_case.control_input_configuration.keys())[0]
134+
treatment_variable = causal_test_case.treatment_variable
135135
treatments = treatment_variable.name
136136
outcome_variable = causal_test_case.outcome_variable
137137

138138
logger.info("treatments: %s", treatments)
139139
logger.info("outcomes: %s", outcome_variable)
140140
minimal_adjustment_set = self.causal_dag.identification(BaseTestCase(treatment_variable, outcome_variable))
141-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in causal_test_case.control_input_configuration}
141+
minimal_adjustment_set = minimal_adjustment_set - set(treatment_variable.name)
142142
minimal_adjustment_set = minimal_adjustment_set - set(outcome_variable.name)
143-
assert all(
144-
(v.name not in minimal_adjustment_set for v in causal_test_case.control_input_configuration)
145-
), "Treatment vars in adjustment set"
146-
assert outcome_variable not in minimal_adjustment_set, "Outcome vars in adjustment set"
143+
147144
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]
148145

149146
if self._check_positivity_violation(variables_for_positivity):
@@ -172,8 +169,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
172169
causal_test_result = CausalTestResult(
173170
treatment=estimator.treatment,
174171
outcome=estimator.outcome,
175-
treatment_value=estimator.treatment_values,
176-
control_value=estimator.control_values,
172+
treatment_value=estimator.treatment_value,
173+
control_value=estimator.control_value,
177174
adjustment_set=estimator.adjustment_set,
178175
test_value=TestValue("ate", cates_df),
179176
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -185,8 +182,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
185182
causal_test_result = CausalTestResult(
186183
treatment=estimator.treatment,
187184
outcome=estimator.outcome,
188-
treatment_value=estimator.treatment_values,
189-
control_value=estimator.control_values,
185+
treatment_value=estimator.treatment_value,
186+
control_value=estimator.control_value,
190187
adjustment_set=estimator.adjustment_set,
191188
test_value=TestValue("risk_ratio", risk_ratio),
192189
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -198,8 +195,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
198195
causal_test_result = CausalTestResult(
199196
treatment=estimator.treatment,
200197
outcome=estimator.outcome,
201-
treatment_value=estimator.treatment_values,
202-
control_value=estimator.control_values,
198+
treatment_value=estimator.treatment_value,
199+
control_value=estimator.control_value,
203200
adjustment_set=estimator.adjustment_set,
204201
test_value=TestValue("ate", ate),
205202
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
@@ -213,8 +210,8 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
213210
causal_test_result = CausalTestResult(
214211
treatment=estimator.treatment,
215212
outcome=estimator.outcome,
216-
treatment_value=estimator.treatment_values,
217-
control_value=estimator.control_values,
213+
treatment_value=estimator.treatment_value,
214+
control_value=estimator.control_value,
218215
adjustment_set=estimator.adjustment_set,
219216
test_value=TestValue("ate", ate),
220217
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,

causal_testing/testing/causal_test_suite.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from collections import UserDict
2+
from typing import Type, Iterable
3+
from causal_testing.testing.base_test_case import BaseTestCase
4+
from causal_testing.testing.causal_test_case import CausalTestCase
5+
from causal_testing.testing.estimators import Estimator
26

37

48
class CausalTestSuite(UserDict):
@@ -13,15 +17,21 @@ class CausalTestSuite(UserDict):
1317
base_test_case's and execute each causal_test_case with each iterator.
1418
"""
1519

16-
def add_test_object(self, base_test_case, causal_test_case_list, estimators, estimate_type: str = "ate"):
20+
def add_test_object(
21+
self,
22+
base_test_case: BaseTestCase,
23+
causal_test_case_list: Iterable[CausalTestCase],
24+
estimators_classes: Iterable[Type[Estimator]],
25+
estimate_type: str = "ate",
26+
):
1727
"""
1828
A setter object to allow for the easy construction of the dictionary test suite structure
1929
2030
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
2131
:param causal_test_case_list: A list of causal test cases to be executed
22-
:param estimators: A list of estimators, the execute_test_suite function in the TestEngine will produce a list
32+
:param estimators_classes: A list of estimator class references; the execute_test_suite function in the CausalTestEngine will produce a list
2333
of test results for each estimator
2434
:param estimate_type: A string which denotes the type of estimate to return
2535
"""
26-
test_object = {"tests": causal_test_case_list, "estimators": estimators, "estimate_type": estimate_type}
36+
test_object = {"tests": causal_test_case_list, "estimators": estimators_classes, "estimate_type": estimate_type}
2737
self.data[base_test_case] = test_object

0 commit comments

Comments
 (0)