Skip to content

Commit 24316fe

Browse files
Merge pull request #100 from CITCOM-project/refactor-load-data
Refactor the load data function of the CausalTestEngine so it is not dependant on a CausalTestCase
2 parents 30eaaa4 + f5fedc9 commit 24316fe

File tree

9 files changed

+88
-102
lines changed

9 files changed

+88
-102
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The causal testing framework has three core components:
1414
1. Using the causal DAG, identify an estimand for the effect of the intervention on the output of interest. That is, a statistical procedure capable of estimating the causal effect of the intervention on the output.
1515
2. Collect the data to which the statistical procedure will be applied (see Data collection below).
1616
3. Apply a statistical model (e.g. linear regression or causal forest) to the data to obtain a point estimate for the causal effect. Depending on the estimator used, confidence intervals may also be obtained at a specified confidence level e.g. 0.05 corresponds to 95% confidence intervals (optional).
17-
4. Return the casual test result including a point estimate and 95% confidence intervals, usally quantifying the average treatment effect (ATE).
17+
4. Return the casual test result including a point estimate and 95% confidence intervals, usually quantifying the average treatment effect (ATE).
1818
5. Implement and apply a test oracle to the causal test result - that is, a procedure that determines whether the test should pass or fail based on the results. In the simplest case, this takes the form of an assertion which compares the point estimate to the expected causal effect specified in the causal test case.
1919

2020
3. [Data collection](causal_testing/data_collection/README.md): Data for the system-under-test can be collected in two ways: experimentally or observationally. The former involves executing the system-under-test under controlled conditions which, by design, isolate the causal effect of interest (accurate but expensive), while the latter involves collecting suitable previous execution data and utilising our causal knowledge to draw causal inferences (potentially less accurate but efficient). To collect experimental data, the user must implement a single method which runs the system-under-test with a given input configuration. On the other hand, when dealing with observational data, we automatically check whether the data is suitable for the identified estimand in two steps. First, confirm whether the data contains a column for each variable in the causal DAG. Second, we check for [positivity violations](https://www.youtube.com/watch?v=4xc8VkrF98w). If there are positivity violations, we can provide instructions for an execution that will fill the gap (future work).
@@ -84,10 +84,10 @@ data_collector = ObservationalDataCollector(modelling_scenario, data_csv_path)
8484
The actual running of the tests is done using the `CausalTestEngine` class. This is still a work in progress and may change in the future to improve ease of use, but currently proceeds as follows.
8585

8686
```{python}
87-
causal_test_engine = CausalTestEngine(causal_test_case, causal_specification, data_collector) # Instantiate the causal test engine
88-
minimal_adjustment_set = causal_test_engine.load_data(data_csv_path, index_col=0) # Calculate the adjustment set
87+
causal_test_engine = CausalTestEngine(causal_specification, data_collector, data_csv_path, index_col=0) # Instantiate the causal test engine
88+
causal_test_engine.identification(causal_test_case) #Perform identification and produce the minimum adjustment set
8989
treatment_vars = list(causal_test_case.treatment_input_configuration)
90-
minimal_adjustment_set = minimal_adjustment_set - set([v.name for v in treatment_vars]) # Remove the treatment variables from the adjustment set. This is necessary for causal inference to work properly.
90+
minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - set([v.name for v in treatment_vars]) # Remove the treatment variables from the adjustment set. This is necessary for causal inference to work properly.
9191
```
9292

9393
Whether using fresh or pre-existing data, a key aspect of causal inference is estimation. To actually execute a test, we need an estimator. We currently support two estimators: linear regression and causal forest. These can simply be instantiated as per the [documentation](https://causal-testing-framework.readthedocs.io/en/latest/autoapi/causal_testing/testing/estimators/index.html).

causal_testing/data_collection/data_collector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
119119
specified input configuration.
120120
"""
121121

122+
122123
class ObservationalDataCollector(DataCollector):
123124
"""A data collector that extracts data that is relevant to the specified scenario from a csv of execution data."""
124125

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
169169

170170
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
171171
causal_test_result = causal_test_engine.execute_test(
172-
estimation_model, estimate_type=causal_test_case.estimate_type
172+
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
173173
)
174174

175175
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
@@ -203,10 +203,10 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
203203
- estimation_model - Estimator instance for the test being run
204204
"""
205205
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
206-
causal_test_engine = CausalTestEngine(causal_test_case, self.causal_specification, data_collector)
207-
minimal_adjustment_set = causal_test_engine.load_data(index_col=0)
206+
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
207+
causal_test_engine.identification(causal_test_case)
208208
treatment_vars = list(causal_test_case.treatment_input_configuration)
209-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}
209+
minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - {v.name for v in treatment_vars}
210210
estimation_model = estimator(
211211
(list(treatment_vars)[0].name,),
212212
[causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],

causal_testing/testing/causal_test_engine.py

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,9 @@ class CausalTestEngine:
2929
provided as an instance of the CausalTestResult class.
3030
(4) Define a test oracle procedure which uses the causal test results to determine whether the intervention has
3131
had the anticipated causal effect. This should assign a pass/fail value to the CausalTestResult.
32-
"""
33-
34-
def __init__(
35-
self, causal_test_case: CausalTestCase, causal_specification: CausalSpecification, data_collector: DataCollector
36-
):
37-
self.causal_test_case = causal_test_case
38-
self.treatment_variables = list(self.causal_test_case.control_input_configuration)
39-
self.casual_dag, self.scenario = (
40-
causal_specification.causal_dag,
41-
causal_specification.scenario,
42-
)
43-
self.data_collector = data_collector
44-
self.scenario_execution_data_df = pd.DataFrame()
45-
46-
def load_data(self, **kwargs):
47-
"""Load execution data corresponding to the causal test case into a pandas dataframe and return the minimal
48-
adjustment set.
4932
50-
Data can be loaded in two ways:
33+
Data is loaded as part of the "__init__" function
34+
Data can be loaded in two ways:
5135
(1) Experimentally - the model is executed with the treatment and control input configurations under
5236
conditions that guarantee the observed change in outcome must be caused by the change in input
5337
(intervention).
@@ -57,30 +41,43 @@ def load_data(self, **kwargs):
5741
5842
After the data is loaded, both are treated in the same way and, provided the identifiability and modelling
5943
assumptions hold, can be used to estimate the causal effect for the causal test case.
44+
"""
6045

61-
:return self: Update the causal test case's execution data dataframe.
46+
def __init__(self, causal_specification: CausalSpecification, data_collector: DataCollector, **kwargs):
47+
self.casual_dag, self.scenario = (
48+
causal_specification.causal_dag,
49+
causal_specification.scenario,
50+
)
51+
self.data_collector = data_collector
52+
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
53+
self.minimal_adjustment_set = set()
54+
55+
def identification(self, causal_test_case):
56+
"""Identify and return the minimum adjustment set
57+
58+
:param causal_test_case: Causal test Case to get the minimum adjustment set from
6259
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
6360
estimate as opposed to a purely associational estimate.
6461
"""
6562

66-
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
67-
6863
minimal_adjustment_sets = []
69-
if self.causal_test_case.effect == "total":
64+
treatment_variables = list(causal_test_case.control_input_configuration)
65+
if causal_test_case.effect == "total":
7066
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(
71-
[v.name for v in self.treatment_variables], [v.name for v in self.causal_test_case.outcome_variables]
67+
[v.name for v in treatment_variables], [v.name for v in causal_test_case.outcome_variables]
7268
)
73-
elif self.causal_test_case.effect == "direct":
69+
elif causal_test_case.effect == "direct":
7470
minimal_adjustment_sets = self.casual_dag.direct_effect_adjustment_sets(
75-
[v.name for v in self.treatment_variables], [v.name for v in self.causal_test_case.outcome_variables]
71+
[v.name for v in treatment_variables], [v.name for v in causal_test_case.outcome_variables]
7672
)
7773
else:
7874
raise ValueError("Causal effect should be 'total' or 'direct'")
7975

80-
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
81-
return minimal_adjustment_set
76+
self.minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
8277

83-
def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> CausalTestResult:
78+
def execute_test(
79+
self, estimator: Estimator, causal_test_case: CausalTestCase, estimate_type: str = "ate"
80+
) -> CausalTestResult:
8481
"""Execute a causal test case and return the causal test result.
8582
8683
Test case execution proceeds with the following steps:
@@ -94,31 +91,31 @@ def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> Caus
9491
(7) Apply test oracle procedure to assign a pass/fail to the CausalTestResult and return.
9592
9693
:param estimator: A reference to an Estimator class.
94+
:param causal_test_case: The CausalTestCase object to be tested
9795
:param estimate_type: A string which denotes the type of estimate to return, ATE or CATE.
9896
:return causal_test_result: A CausalTestResult for the executed causal test case.
9997
"""
10098
if self.scenario_execution_data_df.empty:
10199
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
102100
if estimator.df is None:
103101
estimator.df = self.scenario_execution_data_df
104-
treatments = [v.name for v in self.treatment_variables]
105-
outcomes = [v.name for v in self.causal_test_case.outcome_variables]
106-
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(treatments, outcomes)
107-
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
102+
treatment_variables = list(causal_test_case.control_input_configuration)
103+
treatments = [v.name for v in treatment_variables]
104+
outcomes = [v.name for v in causal_test_case.outcome_variables]
108105

109106
logger.info("treatments: %s", treatments)
110107
logger.info("outcomes: %s", outcomes)
111-
logger.info("minimal_adjustment_set: %s", minimal_adjustment_set)
108+
logger.info("minimal_adjustment_set: %s", self.minimal_adjustment_set)
112109

113-
minimal_adjustment_set = minimal_adjustment_set - {
114-
v.name for v in self.causal_test_case.control_input_configuration
110+
minimal_adjustment_set = self.minimal_adjustment_set - {
111+
v.name for v in causal_test_case.control_input_configuration
115112
}
116-
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in self.causal_test_case.outcome_variables}
113+
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in causal_test_case.outcome_variables}
117114
assert all(
118-
(v.name not in minimal_adjustment_set for v in self.causal_test_case.control_input_configuration)
115+
(v.name not in minimal_adjustment_set for v in causal_test_case.control_input_configuration)
119116
), "Treatment vars in adjustment set"
120117
assert all(
121-
(v.name not in minimal_adjustment_set for v in self.causal_test_case.outcome_variables)
118+
(v.name not in minimal_adjustment_set for v in causal_test_case.outcome_variables)
122119
), "Outcome vars in adjustment set"
123120

124121
variables_for_positivity = list(minimal_adjustment_set) + treatments + outcomes
@@ -142,7 +139,7 @@ def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> Caus
142139
control_value=estimator.control_values,
143140
adjustment_set=estimator.adjustment_set,
144141
ate=cates_df,
145-
effect_modifier_configuration=self.causal_test_case.effect_modifier_configuration,
142+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
146143
confidence_intervals=confidence_intervals,
147144
)
148145
elif estimate_type == "risk_ratio":
@@ -155,7 +152,7 @@ def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> Caus
155152
control_value=estimator.control_values,
156153
adjustment_set=estimator.adjustment_set,
157154
ate=risk_ratio,
158-
effect_modifier_configuration=self.causal_test_case.effect_modifier_configuration,
155+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
159156
confidence_intervals=confidence_intervals,
160157
)
161158
elif estimate_type == "ate":
@@ -168,7 +165,7 @@ def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> Caus
168165
control_value=estimator.control_values,
169166
adjustment_set=estimator.adjustment_set,
170167
ate=ate,
171-
effect_modifier_configuration=self.causal_test_case.effect_modifier_configuration,
168+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
172169
confidence_intervals=confidence_intervals,
173170
)
174171
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
@@ -183,7 +180,7 @@ def execute_test(self, estimator: Estimator, estimate_type: str = "ate") -> Caus
183180
control_value=estimator.control_values,
184181
adjustment_set=estimator.adjustment_set,
185182
ate=ate,
186-
effect_modifier_configuration=self.causal_test_case.effect_modifier_configuration,
183+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
187184
confidence_intervals=confidence_intervals,
188185
)
189186
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)

examples/covasim_/doubling_beta/causal_test_beta.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
4242
results_dict = {'association': {},
4343
'causation': {}}
4444

45-
# Read in the observational data and perform identification
45+
# Read in the observational data, perform identification, and setup the causal_test_engine
4646
past_execution_df = pd.read_csv(observational_data_path)
47-
_, causal_test_engine = identification(observational_data_path)
47+
_, causal_test_engine, causal_test_case = engine_setup(observational_data_path)
4848

4949
linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
5050
{'avg_age', 'contacts'}, # We use custom adjustment set
@@ -53,15 +53,15 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
5353

5454
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
5555
linear_regression_estimator.add_squared_term_to_df('beta')
56-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, 'ate')
56+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
5757

5858
# Repeat for association estimate (no adjustment)
5959
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
6060
set(),
6161
('cum_infections',),
6262
df=past_execution_df)
6363
no_adjustment_linear_regression_estimator.add_squared_term_to_df('beta')
64-
association_test_result = causal_test_engine.execute_test(no_adjustment_linear_regression_estimator, 'ate')
64+
association_test_result = causal_test_engine.execute_test(no_adjustment_linear_regression_estimator, causal_test_case, 'ate')
6565

6666
# Store results for plotting
6767
results_dict['association'] = {'ate': association_test_result.ate,
@@ -83,7 +83,7 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
8383
('cum_infections',),
8484
df=counterfactual_past_execution_df)
8585
counterfactual_linear_regression_estimator.add_squared_term_to_df('beta')
86-
counterfactual_causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, 'ate')
86+
counterfactual_causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
8787
results_dict['counterfactual'] = {'ate': counterfactual_causal_test_result.ate,
8888
'cis': counterfactual_causal_test_result.confidence_intervals,
8989
'df': counterfactual_past_execution_df}
@@ -179,7 +179,7 @@ def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: b
179179
age_contact_fig.savefig(outpath_base_str + "age_contact_executions.pdf", format="pdf")
180180

181181

182-
def identification(observational_data_path):
182+
def engine_setup(observational_data_path):
183183
# 1. Read in the Causal DAG
184184
causal_dag = CausalDAG('dag.dot')
185185

@@ -213,12 +213,12 @@ def identification(observational_data_path):
213213
data_collector = ObservationalDataCollector(scenario, observational_data_path)
214214

215215
# 7. Create an instance of the causal test engine
216-
causal_test_engine = CausalTestEngine(causal_test_case, causal_specification, data_collector)
216+
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
217217

218218
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
219-
minimal_adjustment_set = causal_test_engine.load_data(index_col=0)
219+
causal_test_engine.identification(causal_test_case)
220220

221-
return minimal_adjustment_set, causal_test_engine
221+
return causal_test_engine.minimal_adjustment_set, causal_test_engine, causal_test_case
222222

223223

224224
def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):

0 commit comments

Comments
 (0)