@@ -276,8 +276,8 @@ def setup(observational_data):
276
276
277
277
def plot_doubling_beta_CATEs (results_dict , title , figure = None , axes = None , row = None , col = None ):
278
278
# Get the CATE as a percentage for association and causation
279
- ate = results_dict ["causation" ]["ate" ]
280
- association_ate = results_dict ["association" ]["ate" ]
279
+ ate = results_dict ["causation" ]["ate" ][ 0 ]
280
+ association_ate = results_dict ["association" ]["ate" ][ 0 ]
281
281
282
282
causation_df = results_dict ["causation" ]["df" ]
283
283
association_df = results_dict ["association" ]["df" ]
@@ -288,11 +288,10 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=No
288
288
# Get 95% confidence intervals for association and causation
289
289
ate_cis = results_dict ["causation" ]["cis" ]
290
290
association_ate_cis = results_dict ["association" ]["cis" ]
291
- percentage_causal_ate_cis = [round (((ci / causation_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in ate_cis ]
291
+ percentage_causal_ate_cis = [round (((ci [ 0 ] / causation_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in ate_cis ]
292
292
percentage_association_ate_cis = [
293
- round (((ci / association_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in association_ate_cis
293
+ round (((ci [ 0 ] / association_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in association_ate_cis
294
294
]
295
-
296
295
# Convert confidence intervals to errors for plotting
297
296
percentage_causal_errs = [
298
297
percentage_ate - percentage_causal_ate_cis [0 ],
@@ -314,9 +313,9 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=No
314
313
if "counterfactual" in results_dict .keys ():
315
314
cf_ate = results_dict ["counterfactual" ]["ate" ]
316
315
cf_df = results_dict ["counterfactual" ]["df" ]
317
- percentage_cf_ate = round ((cf_ate / cf_df ["cum_infections" ].mean ()) * 100 , 3 )
316
+ percentage_cf_ate = round ((cf_ate [ 0 ] / cf_df ["cum_infections" ].mean ()) * 100 , 3 )
318
317
cf_ate_cis = results_dict ["counterfactual" ]["cis" ]
319
- percentage_cf_cis = [round (((ci / cf_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in cf_ate_cis ]
318
+ percentage_cf_cis = [round (((ci [ 0 ] / cf_df ["cum_infections" ].mean ()) * 100 ), 3 ) for ci in cf_ate_cis ]
320
319
percentage_cf_errs = [percentage_cf_ate - percentage_cf_cis [0 ], percentage_cf_cis [1 ] - percentage_cf_ate ]
321
320
xs = [0.5 , 1.5 , 2.5 ]
322
321
ys = [association_percentage_ate , percentage_ate , percentage_cf_ate ]
0 commit comments