1
1
import matplotlib .pyplot as plt
2
2
import pandas as pd
3
3
import numpy as np
4
- import glob
5
4
from causal_testing .specification .causal_dag import CausalDAG
6
5
from causal_testing .specification .scenario import Scenario
7
6
from causal_testing .specification .variable import Input , Output
11
10
from causal_testing .testing .causal_test_outcome import Positive
12
11
from causal_testing .testing .intervention import Intervention
13
12
from causal_testing .testing .causal_test_engine import CausalTestEngine
14
- from causal_testing .testing .estimators import LinearRegressionEstimator , CausalForestEstimator
13
+ from causal_testing .testing .estimators import LinearRegressionEstimator
15
14
from matplotlib .pyplot import rcParams
16
15
17
- # Make the plots all fancy
16
+ # Make all graphs publication quality
18
17
plt .rcParams ["figure.figsize" ] = (8 , 8 )
19
18
rc_fonts = {
20
19
"font.size" : 8 ,
28
27
OBSERVATIONAL_DATA_PATH = "./data/10k_observational_data.csv"
29
28
30
29
31
def concatenate_csvs_in_directory(directory_path, output_path):
    """Stack every csv matching a glob pattern into a single csv file.

    All matched csvs are assumed to share the same header and to keep their row index in the first
    column. The frames are stacked vertically in sorted path order and the combined frame is written
    with a fresh 0..n-1 index (``ignore_index=True``), so the per-file indices are discarded.

    :param directory_path: Glob pattern passed to ``glob.glob`` (e.g. ``"./data/runs/*.csv"``). Despite
        the name this is a pattern, not a plain directory.
    :param output_path: Path the concatenated csv is written to.
    :raises ValueError: If the pattern matches no files.
    """
    # glob returns matches in arbitrary, filesystem-dependent order; sorting keeps the row order of
    # the output reproducible across machines and runs.
    csv_paths = sorted(glob.glob(directory_path))
    if not csv_paths:
        # pd.concat([]) would fail with an opaque "No objects to concatenate"; keep the same
        # exception type but say which pattern came up empty.
        raise ValueError(f"No csv files matched the pattern {directory_path!r}")

    frames = [pd.read_csv(csv_path, index_col=0) for csv_path in csv_paths]
    pd.concat(frames, ignore_index=True).to_csv(output_path)
40
-
41
-
42
- def CATE_on_csv (observational_data_path : str , simulate_counterfactuals : bool = False , verbose : bool = False ):
30
+ def doubling_beta_CATE_on_csv (observational_data_path : str , simulate_counterfactuals : bool = False ,
31
+ verbose : bool = False ):
43
32
""" Compute the CATE of increasing beta from 0.016 to 0.032 on cum_infections using the dataframe
44
33
loaded from the specified path. Additionally simulate the counterfactuals by repeating the analysis
45
34
after removing rows with beta==0.032.
46
35
47
36
:param observational_data_path: Path to csv containing observational data for analysis.
48
- :param simulate_counterfactuals: Whether or not to repeat analysis with counterfactuals.
49
- :param verbose: Whether or not to print verbose details (causal test results).
37
+ :param simulate_counterfactuals: Whether to repeat analysis with counterfactuals.
38
+ :param verbose: Whether to print verbose details (causal test results).
50
39
:return results_dict: A nested dictionary containing results (ate and confidence intervals)
51
40
for association, causation, and counterfactual (if completed).
52
41
"""
@@ -86,7 +75,7 @@ def CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = F
86
75
print (f"Association:\n { association_test_result } " )
87
76
print (f"Causation:\n { causal_test_result } " )
88
77
89
- # Repeat causal inference after deleting all rows with treatment value
78
+ # Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
90
79
if simulate_counterfactuals :
91
80
counterfactual_past_execution_df = past_execution_df [past_execution_df ['beta' ] != 0.032 ]
92
81
counterfactual_linear_regression_estimator = LinearRegressionEstimator (('beta' ,), 0.032 , 0.016 ,
@@ -104,14 +93,14 @@ def CATE_on_csv(observational_data_path: str, simulate_counterfactuals: bool = F
104
93
return results_dict
105
94
106
95
107
- def manual_CATE (observational_data_path : str , simulate_counterfactual : bool = False , verbose : bool = False ):
96
+ def doubling_beta_CATEs (observational_data_path : str , simulate_counterfactual : bool = False , verbose : bool = False ):
108
97
""" Compute the CATE for the effect of doubling beta across simulations with different age demographics.
109
98
To compute the CATE, this method splits the observational data into high and low age data and computes the
110
99
ATE using this data and a linear regression model.
111
100
112
- Since this method already adjusts for age, adding age as
113
- an adjustment to the LR model will have no impact. However, adding contacts as an adjustment should reduce
114
- bias and reveal the average causal effect of doubling beta in simulations of a particular age demographic. """
101
+ Since this method already adjusts for age, adding age as an adjustment to the LR model will have no impact.
102
+ However, adding contacts as an adjustment should reduce bias and reveal the average causal effect of doubling beta
103
+ in simulations of a particular age demographic. """
115
104
116
105
# Create separate subplots for each more specific causal question
117
106
all_fig , all_axes = plt .subplots (1 , 1 , figsize = (4 , 3 ), squeeze = False )
@@ -121,8 +110,8 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
121
110
# Apply CT to get the ATE over all executions
122
111
if verbose :
123
112
print ("Running causal tests for all data..." )
124
- all_data_results_dict = CATE_on_csv (observational_data_path , simulate_counterfactual , verbose = False )
125
- plot_manual_CATE_result (all_data_results_dict , "All Data" , all_fig , all_axes , row = 0 , col = 0 )
113
+ all_data_results_dict = doubling_beta_CATE_on_csv (observational_data_path , simulate_counterfactual , verbose = False )
114
+ plot_doubling_beta_CATEs (all_data_results_dict , "All Data" , all_fig , all_axes , row = 0 , col = 0 )
126
115
127
116
# Split data into age-specific strata
128
117
past_execution_df = pd .read_csv (observational_data_path )
@@ -141,13 +130,14 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
141
130
"./data/bessemer/older_population.csv" ]
142
131
143
132
for col , separated_observational_data_path in enumerate (separated_observational_data_paths ):
144
- age_data_results_dict = CATE_on_csv (separated_observational_data_path , simulate_counterfactual , verbose = False )
133
+ age_data_results_dict = doubling_beta_CATE_on_csv (separated_observational_data_path , simulate_counterfactual ,
134
+ verbose = False )
145
135
age_stratified_df = pd .read_csv (separated_observational_data_path )
146
136
age_stratified_df_avg_age = round (age_stratified_df ["avg_age" ].mean (), 1 )
147
137
if verbose :
148
138
print (f"Running causal tests for data with average age of { age_stratified_df_avg_age } " )
149
- plot_manual_CATE_result (age_data_results_dict , f"Age={ age_stratified_df_avg_age } " , age_fig , age_axes , row = 0 ,
150
- col = col )
139
+ plot_doubling_beta_CATEs (age_data_results_dict , f"Age={ age_stratified_df_avg_age } " , age_fig , age_axes , row = 0 ,
140
+ col = col )
151
141
152
142
# Split df into contact-specific strata
153
143
min_contacts = np .floor (age_stratified_df ['contacts' ].min ())
@@ -167,16 +157,17 @@ def manual_CATE(observational_data_path: str, simulate_counterfactual: bool = Fa
167
157
168
158
# Compute the CATE for each age-contact group
169
159
for row , age_contact_data_path in enumerate (contact_observational_data_paths ):
170
- age_contact_data_results_dict = CATE_on_csv (age_contact_data_path , simulate_counterfactual , verbose = False )
160
+ age_contact_data_results_dict = doubling_beta_CATE_on_csv (age_contact_data_path , simulate_counterfactual ,
161
+ verbose = False )
171
162
age_contact_stratified_df = pd .read_csv (age_contact_data_path )
172
163
age_contact_stratified_df_avg_contacts = round (age_contact_stratified_df ["contacts" ].mean (), 1 )
173
164
if verbose :
174
165
print (f"Running causal tests for data with average age of { age_stratified_df_avg_age } and "
175
166
f"{ age_contact_stratified_df_avg_contacts } average household contacts." )
176
- plot_manual_CATE_result (age_contact_data_results_dict ,
177
- f"Age={ age_stratified_df_avg_age } "
178
- f"Contacts={ age_contact_stratified_df_avg_contacts } " ,
179
- age_contact_fig , age_contact_axes , row = row , col = col )
167
+ plot_doubling_beta_CATEs (age_contact_data_results_dict ,
168
+ f"Age={ age_stratified_df_avg_age } "
169
+ f"Contacts={ age_contact_stratified_df_avg_contacts } " ,
170
+ age_contact_fig , age_contact_axes , row = row , col = col )
180
171
181
172
# Save plots
182
173
if simulate_counterfactual :
@@ -230,32 +221,7 @@ def identification(observational_data_path):
230
221
return minimal_adjustment_set , causal_test_engine
231
222
232
223
233
def causal_forest_CATE(observational_data_path):
    """Run two causal forest 'cate' tests for doubling beta and plot both results.

    The causal test engine is obtained from ``identification``. Two ``CausalForestEstimator`` objects
    are built for treatment ``beta`` (0.032 against a control of 0.016) and outcome ``cum_infections``,
    with ``avg_age`` as the effect modifier: one adjusting for ``avg_age`` and ``contacts``, one with an
    empty adjustment set. Each result is passed to ``plot_causal_forest_result`` together with the
    observational data.

    :param observational_data_path: Path to the csv containing the observational data.
    """
    _, engine = identification(observational_data_path)

    def build_estimator(adjustment_set):
        # Both estimators are identical apart from the adjustment set.
        return CausalForestEstimator(
            treatment=('beta',),
            treatment_values=0.032,
            control_values=0.016,
            adjustment_set=adjustment_set,
            outcome=('cum_infections',),
            effect_modifiers={engine.scenario.variables['avg_age']})

    adjusted_estimator = build_estimator({'avg_age', 'contacts'})
    unadjusted_estimator = build_estimator(set())

    # Execute both tests before plotting; the adjusted test runs first.
    adjusted_result = engine.execute_test(adjusted_estimator, 'cate')
    unadjusted_result = engine.execute_test(unadjusted_estimator, 'cate')

    observational_data = pd.read_csv(observational_data_path)
    titled_results = (
        (adjusted_result, "Causal Forest Adjusted for Age and Contacts."),
        (unadjusted_result, "Causal Forest Without Adjustment."),
    )
    for test_result, plot_title in titled_results:
        plot_causal_forest_result(test_result, observational_data, plot_title)
256
-
257
-
258
- def plot_manual_CATE_result (results_dict , title , figure = None , axes = None , row = None , col = None ):
224
+ def plot_doubling_beta_CATEs (results_dict , title , figure = None , axes = None , row = None , col = None ):
259
225
# Get the CATE as a percentage for association and causation
260
226
ate = results_dict ['causation' ]['ate' ]
261
227
association_ate = results_dict ['association' ]['ate' ]
@@ -283,10 +249,10 @@ def plot_manual_CATE_result(results_dict, title, figure=None, axes=None, row=Non
283
249
ys = [association_percentage_ate , percentage_ate ]
284
250
yerrs = [percentage_association_errs , percentage_causal_errs ]
285
251
xticks = ['Association' , 'Causation' ]
286
- print (f"Causal ATE: { percentage_ate } { percentage_causal_ate_cis } " )
287
- print (f"Causal executions: { len (causation_df )} " )
288
252
print (f"Association ATE: { association_percentage_ate } { percentage_association_ate_cis } " )
289
253
print (f"Association executions: { len (association_df )} " )
254
+ print (f"Causal ATE: { percentage_ate } { percentage_causal_ate_cis } " )
255
+ print (f"Causal executions: { len (causation_df )} " )
290
256
if 'counterfactual' in results_dict .keys ():
291
257
cf_ate = results_dict ['counterfactual' ]['ate' ]
292
258
cf_df = results_dict ['counterfactual' ]['df' ]
@@ -309,33 +275,5 @@ def plot_manual_CATE_result(results_dict, title, figure=None, axes=None, row=Non
309
275
figure .supylabel (r"\% Change in Cumulative Infections (ATE)" , fontsize = 10 )
310
276
311
277
312
def plot_causal_forest_result(causal_forest_test_result, previous_data_df, title=None, filter_data_by_variant=False):
    """Scatter the per-row CATE, as a percentage of cumulative infections, against average age.

    The ``ate`` attribute of the test result is sorted by index, stripped of its ``avg_age`` column and
    joined onto the observational data (assumes ``ate`` is a DataFrame with ``cate`` and ``avg_age``
    columns; dropping ``avg_age`` presumably avoids a column clash in the join -- confirm against the
    estimator). One scatter series is drawn per distinct ``location`` and the figure is shown.

    :param causal_forest_test_result: Test result whose ``ate`` attribute holds the estimates.
    :param previous_data_df: Observational data; the code reads its ``cum_infections``, ``avg_age`` and
        ``location`` columns, plus ``variants`` when filtering is enabled.
    :param title: Title placed on the axes.
    :param filter_data_by_variant: If true, keep only rows whose ``variants`` value is ``'beta'``.
    """
    cate_estimates = causal_forest_test_result.ate.sort_index().drop(columns='avg_age')
    merged_df = previous_data_df.join(cate_estimates)
    merged_df['percentage_increase'] = merged_df['cate'] / merged_df['cum_infections'] * 100

    _, ax = plt.subplots()
    if filter_data_by_variant:
        merged_df = merged_df[merged_df['variants'] == 'beta']

    # One labelled series per location so the legend can tell them apart.
    for location in merged_df['location'].unique():
        location_rows = merged_df[merged_df['location'] == location]
        ax.scatter(location_rows['avg_age'], location_rows['percentage_increase'],
                   s=1, alpha=.3, label=location)

    ax.set_title(title)
    ax.set_xlabel("Average age")
    ax.set_ylabel("% change in cumulative infections")
    ax.set_ylim(0, 40)

    # Shrink the axes to 80% height and lift them, leaving room for the legend underneath.
    bounds = ax.get_position()
    ax.set_position([bounds.x0, bounds.y0 + bounds.height * 0.2, bounds.width, bounds.height * 0.8])
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), fancybox=True, ncol=4)
    plt.show()
335
-
336
-
337
278
if __name__ == "__main__" :
338
- # concatenate_csvs_in_directory("./data/bessemer/custom_variants/thursday_31st_march/2k_executions/2k_data/*.csv",
339
- # "./data/10k_observational_data.csv")
340
- manual_CATE (OBSERVATIONAL_DATA_PATH , True , True )
341
- # causal_forest_CATE(OBSERVATIONAL_DATA_PATH)
279
+ doubling_beta_CATEs (OBSERVATIONAL_DATA_PATH , True , True )
0 commit comments