Skip to content

Commit c026f6a

Browse files
Update example_beta.py
1 parent 425329b commit c026f6a

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ def setup(observational_data):
276276

277277
def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
278278
# Get the CATE as a percentage for association and causation
279-
ate = results_dict["causation"]["ate"]
280-
association_ate = results_dict["association"]["ate"]
279+
ate = results_dict["causation"]["ate"][0]
280+
association_ate = results_dict["association"]["ate"][0]
281281

282282
causation_df = results_dict["causation"]["df"]
283283
association_df = results_dict["association"]["df"]
@@ -288,11 +288,10 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=No
288288
# Get 95% confidence intervals for association and causation
289289
ate_cis = results_dict["causation"]["cis"]
290290
association_ate_cis = results_dict["association"]["cis"]
291-
percentage_causal_ate_cis = [round(((ci / causation_df["cum_infections"].mean()) * 100), 3) for ci in ate_cis]
291+
percentage_causal_ate_cis = [round(((ci[0] / causation_df["cum_infections"].mean()) * 100), 3) for ci in ate_cis]
292292
percentage_association_ate_cis = [
293-
round(((ci / association_df["cum_infections"].mean()) * 100), 3) for ci in association_ate_cis
293+
round(((ci[0] / association_df["cum_infections"].mean()) * 100), 3) for ci in association_ate_cis
294294
]
295-
296295
# Convert confidence intervals to errors for plotting
297296
percentage_causal_errs = [
298297
percentage_ate - percentage_causal_ate_cis[0],
@@ -314,9 +313,9 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=No
314313
if "counterfactual" in results_dict.keys():
315314
cf_ate = results_dict["counterfactual"]["ate"]
316315
cf_df = results_dict["counterfactual"]["df"]
317-
percentage_cf_ate = round((cf_ate / cf_df["cum_infections"].mean()) * 100, 3)
316+
percentage_cf_ate = round((cf_ate[0] / cf_df["cum_infections"].mean()) * 100, 3)
318317
cf_ate_cis = results_dict["counterfactual"]["cis"]
319-
percentage_cf_cis = [round(((ci / cf_df["cum_infections"].mean()) * 100), 3) for ci in cf_ate_cis]
318+
percentage_cf_cis = [round(((ci[0] / cf_df["cum_infections"].mean()) * 100), 3) for ci in cf_ate_cis]
320319
percentage_cf_errs = [percentage_cf_ate - percentage_cf_cis[0], percentage_cf_cis[1] - percentage_cf_ate]
321320
xs = [0.5, 1.5, 2.5]
322321
ys = [association_percentage_ate, percentage_ate, percentage_cf_ate]

0 commit comments

Comments
 (0)