
Commit 0c194b5

Authored by Andrew Clark
Merge pull request #61 from CITCOM-project/lr91_example
Added and refactored Covasim and LR91 case studies
2 parents 06f72c4 + ff34ff9 commit 0c194b5

File tree: 9 files changed, +351 -1362 lines changed


causal_testing/data_collection/data_collector.py

Lines changed: 6 additions & 12 deletions
@@ -90,23 +90,17 @@ def __init__(self, scenario: Scenario, control_input_configuration: dict, treatm
         self.treatment_input_configuration = treatment_input_configuration
         self.n_repeats = n_repeats
 
-    @abstractmethod
     def collect_data(self, **kwargs) -> pd.DataFrame:
-        """Populate the dataframe with execution data.
+        """Run the system-under-test with control and treatment input configurations to obtain experimental data in
+        which the causal effect of interest is isolated by design.
 
         :return: A pandas dataframe containing execution data for the system-under-test in both control and treatment
         executions.
         """
-        # Check runtime configs to make sure they don't violate constraints
-        control_df = self.run_system_with_input_configuration(
-            self.filter_valid_data(self.control_input_configuration, check_pos=False)
-        )
-        treatment_df = self.run_system_with_input_configuration(
-            self.filter_valid_data(self.treatment_input_configuration, check_pos=False)
-        )
-
-        # Need to check final output too just in case we have constraints on output variables
-        return self.filter_valid_data(pd.concat([control_df, treatment_df], keys=["control", "treatment"]))
+        control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
+        treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
+        results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
+        return results_df
 
     @abstractmethod
     def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
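
Since collect_data is no longer abstract, a concrete experimental data collector only has to supply run_system_with_input_configuration. A minimal sketch of what a subclass might look like, assuming the abstract class in this module is ExperimentalDataCollector; the subclass name, the run_model stub, and the output column are illustrative and not part of this commit:

import pandas as pd

from causal_testing.data_collection.data_collector import ExperimentalDataCollector  # assumed base class name


def run_model(**inputs):
    # Hypothetical stand-in for a single run of the system-under-test.
    return {"cum_infections": 100, **inputs}


class ExampleDataCollector(ExperimentalDataCollector):
    def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
        # One row of outputs per repeat; the refactored collect_data then
        # concatenates the control and treatment frames with ignore_index=True.
        rows = [run_model(**input_configuration) for _ in range(self.n_repeats)]
        return pd.DataFrame(rows)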

causal_testing/testing/causal_test_outcome.py

Lines changed: 1 addition & 2 deletions
@@ -108,8 +108,7 @@ class NoEffect(CausalTestOutcome):
     """An extension of TestOutcome representing that the expected causal effect should be zero."""
 
     def apply(self, res: CausalTestResult) -> bool:
-        return res.ci_low() < 0 < res.ci_high()
-
+        return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
 
     def __str__(self):
         return "Unchanged"

examples/covasim_/causal_test_beta.py

Lines changed: 27 additions & 89 deletions
@@ -1,7 +1,6 @@
 import matplotlib.pyplot as plt
 import pandas as pd
 import numpy as np
-import glob
 from causal_testing.specification.causal_dag import CausalDAG
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Input, Output
@@ -11,10 +10,10 @@
 from causal_testing.testing.causal_test_outcome import Positive
 from causal_testing.testing.intervention import Intervention
 from causal_testing.testing.causal_test_engine import CausalTestEngine
-from causal_testing.testing.estimators import LinearRegressionEstimator, CausalForestEstimator
+from causal_testing.testing.estimators import LinearRegressionEstimator
 from matplotlib.pyplot import rcParams
 
-# Make the plots all fancy
+# Make all graphs publication quality
 plt.rcParams["figure.figsize"] = (8, 8)
 rc_fonts = {
     "font.size": 8,
@@ -28,25 +27,15 @@
 OBSERVATIONAL_DATA_PATH = "./data/10k_observational_data.csv"
 
 
-def concatenate_csvs_in_directory(directory_path, output_path):
-    """ Concatenate all csvs in a given directory, assuming all csvs share the same header. This will stack the csvs
-    vertically and will not reset the index.
-    """
-    dfs = []
-    for csv_name in glob.glob(directory_path):
-        dfs.append(pd.read_csv(csv_name, index_col=0))
-    full_df = pd.concat(dfs, ignore_index=True)
-    full_df.to_csv(output_path)
-
-
-def CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = False, verbose: bool = False):
+def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = False,
+                              verbose: bool = False):
     """ Compute the CATE of increasing beta from 0.016 to 0.032 on cum_infections using the dataframe
     loaded from the specified path. Additionally simulate the counterfactuals by repeating the analysis
     after removing rows with beta==0.032.
 
     :param observational_data_path: Path to csv containing observational data for analysis.
-    :param simulate_counterfactuals: Whether or not to repeat analysis with counterfactuals.
-    :param verbose: Whether or not to print verbose details (causal test results).
+    :param simulate_counterfactuals: Whether to repeat analysis with counterfactuals.
+    :param verbose: Whether to print verbose details (causal test results).
     :return results_dict: A nested dictionary containing results (ate and confidence intervals)
     for association, causation, and counterfactual (if completed).
     """
@@ -86,7 +75,7 @@ def CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = F
         print(f"Association:\n{association_test_result}")
         print(f"Causation:\n{causal_test_result}")
 
-    # Repeat causal inference after deleting all rows with treatment value
+    # Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
    if simulate_counterfactuals:
         counterfactual_past_execution_df = past_execution_df[past_execution_df['beta'] != 0.032]
         counterfactual_linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
@@ -104,14 +93,14 @@ def CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = F
     return results_dict
 
 
-def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = False, verbose: bool = False):
+def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: bool = False, verbose: bool = False):
     """ Compute the CATE for the effect of doubling beta across simulations with different age demographics.
     To compute the CATE, this method splits the observational data into high and low age data and computes the
     ATE using this data and a linear regression model.
 
-    Since this method already adjusts for age, adding age as
-    an adjustment to the LR model will have no impact. However, adding contacts as an adjustment should reduce
-    bias and reveal the average causal effect of doubling beta in simulations of a particular age demographic. """
+    Since this method already adjusts for age, adding age as an adjustment to the LR model will have no impact.
+    However, adding contacts as an adjustment should reduce bias and reveal the average causal effect of doubling beta
+    in simulations of a particular age demographic. """
 
     # Create separate subplots for each more specific causal question
     all_fig, all_axes = plt.subplots(1, 1, figsize=(4, 3), squeeze=False)
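
The stratification the docstring describes amounts to splitting the observational data on average age before estimating each ATE. A small sketch of such a split, where the median cut-point and the output file names are illustrative (the committed example instead reads pre-separated CSVs such as ./data/bessemer/older_population.csv):

import pandas as pd

df = pd.read_csv("./data/10k_observational_data.csv")
median_age = df["avg_age"].median()

# Split into low/high average-age strata, one CSV per stratum.
df[df["avg_age"] <= median_age].to_csv("./data/younger_stratum.csv", index=False)
df[df["avg_age"] > median_age].to_csv("./data/older_stratum.csv", index=False)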
@@ -121,8 +110,8 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
     # Apply CT to get the ATE over all executions
     if verbose:
         print("Running causal tests for all data...")
-    all_data_results_dict = CATE_on_csv(observational_data_path, simulate_counterfactual, verbose=False)
-    plot_manual_CATE_result(all_data_results_dict, "All Data", all_fig, all_axes, row=0, col=0)
+    all_data_results_dict = doubling_beta_CATE_on_csv(observational_data_path, simulate_counterfactual, verbose=False)
+    plot_doubling_beta_CATEs(all_data_results_dict, "All Data", all_fig, all_axes, row=0, col=0)
 
     # Split data into age-specific strata
     past_execution_df = pd.read_csv(observational_data_path)
@@ -141,13 +130,14 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
                                            "./data/bessemer/older_population.csv"]
 
     for col, separated_observational_data_path in enumerate(separated_observational_data_paths):
-        age_data_results_dict = CATE_on_csv(separated_observational_data_path, simulate_counterfactual, verbose=False)
+        age_data_results_dict = doubling_beta_CATE_on_csv(separated_observational_data_path, simulate_counterfactual,
+                                                          verbose=False)
         age_stratified_df = pd.read_csv(separated_observational_data_path)
         age_stratified_df_avg_age = round(age_stratified_df["avg_age"].mean(), 1)
         if verbose:
             print(f"Running causal tests for data with average age of {age_stratified_df_avg_age}")
-        plot_manual_CATE_result(age_data_results_dict, f"Age={age_stratified_df_avg_age}", age_fig, age_axes, row=0,
-                                col=col)
+        plot_doubling_beta_CATEs(age_data_results_dict, f"Age={age_stratified_df_avg_age}", age_fig, age_axes, row=0,
+                                 col=col)
 
     # Split df into contact-specific strata
     min_contacts = np.floor(age_stratified_df['contacts'].min())
@@ -167,16 +157,17 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
 
     # Compute the CATE for each age-contact group
     for row, age_contact_data_path in enumerate(contact_observational_data_paths):
-        age_contact_data_results_dict = CATE_on_csv(age_contact_data_path, simulate_counterfactual, verbose=False)
+        age_contact_data_results_dict = doubling_beta_CATE_on_csv(age_contact_data_path, simulate_counterfactual,
+                                                                  verbose=False)
         age_contact_stratified_df = pd.read_csv(age_contact_data_path)
         age_contact_stratified_df_avg_contacts = round(age_contact_stratified_df["contacts"].mean(), 1)
         if verbose:
             print(f"Running causal tests for data with average age of {age_stratified_df_avg_age} and "
                   f"{age_contact_stratified_df_avg_contacts} average household contacts.")
-        plot_manual_CATE_result(age_contact_data_results_dict,
-                                f"Age={age_stratified_df_avg_age} "
-                                f"Contacts={age_contact_stratified_df_avg_contacts}",
-                                age_contact_fig, age_contact_axes, row=row, col=col)
+        plot_doubling_beta_CATEs(age_contact_data_results_dict,
+                                 f"Age={age_stratified_df_avg_age} "
+                                 f"Contacts={age_contact_stratified_df_avg_contacts}",
+                                 age_contact_fig, age_contact_axes, row=row, col=col)
 
     # Save plots
     if simulate_counterfactual:
@@ -230,32 +221,7 @@ def identification(observational_data_path):
     return minimal_adjustment_set, causal_test_engine
 
 
-def causal_forest_CATE(observational_data_path):
-    _, causal_test_engine = identification(observational_data_path)
-    causal_forest_estimator = CausalForestEstimator(
-        treatment=('beta',),
-        treatment_values=0.032,
-        control_values=0.016,
-        adjustment_set={'avg_age', 'contacts'},
-        outcome=('cum_infections',),
-        effect_modifiers={causal_test_engine.scenario.variables['avg_age']})
-    causal_forest_estimator_no_adjustment = CausalForestEstimator(
-        treatment=('beta',),
-        treatment_values=0.032,
-        control_values=0.016,
-        adjustment_set=set(),
-        outcome=('cum_infections',),
-        effect_modifiers={causal_test_engine.scenario.variables['avg_age']})
-
-    # 10. Execute the test case and compare the results
-    causal_test_result = causal_test_engine.execute_test(causal_forest_estimator, 'cate')
-    association_test_result = causal_test_engine.execute_test(causal_forest_estimator_no_adjustment, 'cate')
-    observational_data = pd.read_csv(observational_data_path)
-    plot_causal_forest_result(causal_test_result, observational_data, "Causal Forest Adjusted for Age and Contacts.")
-    plot_causal_forest_result(association_test_result, observational_data, "Causal Forest Without Adjustment.")
-
-
-def plot_manual_CATE_result(results_dict, title, figure=None, axes=None, row=None, col=None):
+def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
     # Get the CATE as a percentage for association and causation
     ate = results_dict['causation']['ate']
     association_ate = results_dict['association']['ate']
@@ -283,10 +249,10 @@ def plot_manual_CATE_result(results_dict, title, figure=None, axes=None, row=Non
     ys = [association_percentage_ate, percentage_ate]
     yerrs = [percentage_association_errs, percentage_causal_errs]
     xticks = ['Association', 'Causation']
-    print(f"Causal ATE: {percentage_ate} {percentage_causal_ate_cis}")
-    print(f"Causal executions: {len(causation_df)}")
     print(f"Association ATE: {association_percentage_ate} {percentage_association_ate_cis}")
     print(f"Association executions: {len(association_df)}")
+    print(f"Causal ATE: {percentage_ate} {percentage_causal_ate_cis}")
+    print(f"Causal executions: {len(causation_df)}")
     if 'counterfactual' in results_dict.keys():
         cf_ate = results_dict['counterfactual']['ate']
         cf_df = results_dict['counterfactual']['df']
@@ -309,33 +275,5 @@ def plot_manual_CATE_result(results_dict, title, figure=None, axes=None, row=Non
     figure.supylabel(r"\% Change in Cumulative Infections (ATE)", fontsize=10)
 
 
-def plot_causal_forest_result(causal_forest_test_result, previous_data_df, title=None, filter_data_by_variant=False):
-    sorted_causal_forest_test_result = causal_forest_test_result.ate.sort_index()
-    no_avg_age_causal_forest_test_result = sorted_causal_forest_test_result.drop(columns='avg_age')
-    observational_df_with_results = previous_data_df.join(no_avg_age_causal_forest_test_result)
-    observational_df_with_results['percentage_increase'] = \
-        (observational_df_with_results['cate'] / observational_df_with_results['cum_infections']) * 100
-    fig, ax = plt.subplots()
-    if filter_data_by_variant:
-        observational_df_with_results = observational_df_with_results.loc[observational_df_with_results['variants']
-                                                                          == 'beta']
-    for location in observational_df_with_results.location.unique():
-        location_variant_df = observational_df_with_results.loc[observational_df_with_results['location'] == location]
-        xs = location_variant_df['avg_age']
-        ys = location_variant_df['percentage_increase']
-        ax.scatter(xs, ys, s=1, alpha=.3, label=location)
-    ax.set_ylabel("% change in cumulative infections")
-    ax.set_xlabel("Average age")
-    ax.set_title(title)
-    ax.set_ylim(0, 40)
-    box = ax.get_position()
-    ax.set_position([box.x0, box.y0 + box.height * 0.2, box.width, box.height * 0.8])
-    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), fancybox=True, ncol=4)
-    plt.show()
-
-
 if __name__ == "__main__":
-    # concatenate_csvs_in_directory("./data/bessemer/custom_variants/thursday_31st_march/2k_executions/2k_data/*.csv",
-    #                               "./data/10k_observational_data.csv")
-    manual_CATE(OBSERVATIONAL_DATA_PATH, True, True)
-    # causal_forest_CATE(OBSERVATIONAL_DATA_PATH)
+    doubling_beta_CATEs(OBSERVATIONAL_DATA_PATH, True, True)
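
With the renames in place, a single stratified CSV can still be analysed on its own without going through doubling_beta_CATEs. A hedged usage sketch, assuming it is run from the examples/covasim_ directory and using one of the strata files already referenced above:

from causal_test_beta import doubling_beta_CATE_on_csv

results = doubling_beta_CATE_on_csv("./data/bessemer/older_population.csv",
                                    simulate_counterfactuals=True, verbose=True)
# The nested results dict also holds 'association' and 'counterfactual' entries.
print(results["causation"]["ate"])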
