Skip to content

Commit 3494651

Browse files
authored
Merge pull request #52 from epiforecasts/feature/add-coverage-panel-issue-51
Add prediction interval coverage panel to zoom_25A figure
2 parents 141b2fe + 8aa1946 commit 3494651

15 files changed

+291
-63
lines changed

.lintr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ linters: all_linters(
66
indentation_linter = NULL,
77
object_name_linter = NULL,
88
object_usage_linter = NULL,
9-
return_linter(return_style = "explicit"),
9+
return_linter = NULL,
1010
cyclocomp_linter(25L)
1111
)
1212
exclusions: list(

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ importFrom(rlang,sym)
4141
importFrom(scoringutils,add_relative_skill)
4242
importFrom(scoringutils,as_forecast_quantile)
4343
importFrom(scoringutils,bias_quantile)
44+
importFrom(scoringutils,get_coverage)
4445
importFrom(scoringutils,score)
4546
importFrom(scoringutils,summarise_scores)
4647
importFrom(tidyr,pivot_longer)

R/compute_bias.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
#' @importFrom dplyr filter mutate
99
#' @importFrom tidyr pivot_longer
1010
#' @autoglobal
11-
prepare_data_for_scoring_25A <- function(df_mult_nowcasts,
12-
clade = "25A",
13-
horizon_range = c(-31, 10)) {
11+
prepare_data_for_scoring <- function(df_mult_nowcasts,
12+
clade = "25A",
13+
horizon_range = c(-31, 10)) {
1414
# Filter for specified clade and horizon range
1515
df_filtered <- df_mult_nowcasts |>
1616
filter(clade == !!clade) |>
@@ -45,7 +45,7 @@ prepare_data_for_scoring_25A <- function(df_mult_nowcasts,
4545
#' @importFrom dplyr filter group_by summarise
4646
#' @importFrom scoringutils as_forecast_quantile score bias_quantile
4747
#' @autoglobal
48-
compute_bias_25A <- function(df_prepared, locs, nowcast_dates) {
48+
compute_bias <- function(df_prepared, locs, nowcast_dates) {
4949
# Filter to specific locations and nowcast dates
5050
df_to_score <- filter(
5151
df_prepared,

R/compute_coverage.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#' Compute prediction interval coverage using scoringutils
2+
#'
3+
#' @param df_prepared Prepared data in long format
4+
#' @param locs Vector of location codes to include
5+
#' @param nowcast_dates Vector of nowcast dates to include
6+
#' @param intervals Numeric vector of interval ranges to include
7+
#' (default: c(50, 95))
8+
#'
9+
#' @returns Data frame with coverage for the requested `intervals`,
10+
#' by model, location, nowcast_date, target_date, and clade
11+
#' @importFrom dplyr filter group_by summarise mutate
12+
#' @importFrom scoringutils as_forecast_quantile score get_coverage
13+
#' @autoglobal
14+
compute_coverage <- function(df_prepared, locs, nowcast_dates,
15+
intervals = c(50, 95)) {
16+
# Filter to specific locations and nowcast dates
17+
df_to_score <- filter(
18+
df_prepared,
19+
location %in% locs,
20+
nowcast_date %in% nowcast_dates
21+
)
22+
23+
# Convert to scoringutils forecast object
24+
forecast_obj <- scoringutils::as_forecast_quantile(
25+
df_to_score,
26+
forecast_unit = c(
27+
"model_id", "location", "nowcast_date",
28+
"target_date", "clade"
29+
),
30+
observed = "observed",
31+
predicted = "predicted",
32+
quantile_level = "quantile_level"
33+
)
34+
all_coverage <- scoringutils::get_coverage(
35+
forecast_obj,
36+
by = c(
37+
"location", "nowcast_date", "target_date",
38+
"model_id", "clade"
39+
)
40+
)
41+
coverage <- filter(
42+
all_coverage,
43+
interval_range %in% c(intervals)
44+
)
45+
46+
return(coverage)
47+
}

R/fig_zoom_clade_mult_nowcasts.R

Lines changed: 121 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ get_plot_model_preds_mult <- function(model_preds_mult_nowcasts,
1616
mutate(horizon = as.integer(target_date - nowcast_date)) |>
1717
filter(horizon <= max(horizon_to_plot), horizon >= min(horizon_to_plot))
1818

19-
weekly_obs_data <- daily_to_weekly(final_eval_data) |>
19+
# Use daily observations instead of weekly
20+
daily_obs_data <- final_eval_data |>
2021
filter(location %in% unique(df_filt$location))
21-
total_seq <- weekly_obs_data |>
22+
total_seq <- daily_obs_data |>
2223
group_by(date, location) |>
2324
summarise(n_seq = sum(sequences))
24-
weekly_obs <- left_join(weekly_obs_data, total_seq) |>
25+
daily_obs <- left_join(daily_obs_data, total_seq) |>
2526
filter(clades_modeled == clade_to_zoom)
2627

2728
plot_comps <- plot_components()
@@ -38,7 +39,8 @@ get_plot_model_preds_mult <- function(model_preds_mult_nowcasts,
3839
ymax = q_0.75, fill = model_id,
3940
group = nowcast_date
4041
),
41-
alpha = 0.2
42+
alpha = 0.2,
43+
show.legend = FALSE
4244
) +
4345
geom_ribbon(
4446
aes(
@@ -47,41 +49,47 @@ get_plot_model_preds_mult <- function(model_preds_mult_nowcasts,
4749
ymax = q_0.975, fill = model_id,
4850
group = nowcast_date
4951
),
50-
alpha = 0.1
52+
alpha = 0.1,
53+
show.legend = FALSE
5154
) +
5255
geom_point(
53-
data = weekly_obs,
54-
aes(x = date, y = sequences / n_seq),
55-
color = "#CAB2D6"
56-
) +
57-
geom_line(
58-
data = weekly_obs,
59-
aes(x = date, y = sequences / n_seq),
60-
color = "#CAB2D6"
56+
data = daily_obs,
57+
aes(x = date, y = sequences / n_seq, fill = "25A"),
58+
color = "#CAB2D6",
59+
shape = 21,
60+
size = 0.8
6161
) +
6262
facet_grid(vars(model_id), vars(location)) +
63+
coord_cartesian(ylim = c(0, 1)) +
6364
get_plot_theme(dates = TRUE) +
6465
scale_color_manual(
6566
name = "Model",
6667
values = plot_comps$model_colors
6768
) +
6869
scale_fill_manual(
69-
name = "Model",
70-
values = plot_comps$model_colors
70+
name = "Clade",
71+
values = c(plot_comps$model_colors, "25A" = "#CAB2D6"),
72+
breaks = "25A"
7173
) +
7274
xlab("") +
7375
ylab("Model predictions across nowcast dates") +
7476
guides(
7577
color = "none",
76-
fill = "none"
78+
fill = guide_legend(
79+
title.position = "top",
80+
title.hjust = 0.5,
81+
nrow = 1
82+
)
7783
) +
7884
scale_x_date(
7985
limits = c(min(df_filt$target_date), max(df_filt$target_date)),
8086
date_breaks = "1 week",
8187
date_labels = "%d %b %Y"
8288
) +
8389
ggtitle("25A emergence") +
84-
theme(axis.text.x = element_blank())
90+
theme(
91+
plot.margin = margin(5.5, 5.5, 5.5, 40, "pt") # Increase left margin
92+
)
8593

8694
return(p)
8795
}
@@ -136,27 +144,25 @@ get_plot_scores_by_date <- function(scores,
136144
aes(yintercept = energy_score, color = model),
137145
linetype = "dashed"
138146
) +
139-
facet_wrap(~location, scales = "free_y") +
147+
facet_wrap(~location, ncol = 3, scales = "free_y") +
140148
get_plot_theme(dates = TRUE) +
141149
scale_color_manual(
142150
name = "Model",
143151
values = plot_comps$model_colors
144152
) +
145153
xlab("") +
146154
guides(
147-
color = guide_legend(
148-
title.position = "top",
149-
title.hjust = 0.5,
150-
nrow = 1
151-
)
155+
color = "none"
152156
) +
153-
ylab("Average energy score") +
157+
ylab("Average\nenergy score") +
154158
scale_x_date(
155159
limits = date_range,
156160
date_breaks = "1 week",
157161
date_labels = "%d %b %Y"
158162
) +
159-
theme(axis.text.x = element_blank())
163+
theme(
164+
plot.margin = margin(5.5, 5.5, 5.5, 40, "pt") # Increase left margin
165+
)
160166
return(p)
161167
}
162168

@@ -166,13 +172,20 @@ get_plot_scores_by_date <- function(scores,
166172
#' @param locs Vector of character strings of locations
167173
#' @param nowcast_dates Set of nowcast dates to include
168174
#' @param date_range Range of dates to plot
175+
#' @param plot_name File name (without the `.png` extension) of the saved plot
176+
#' @param output_fp Directory the plot is saved to (created if missing)
169177
#'
170178
#' @returns ggplot
171179
#' @autoglobal
172180
get_plot_bias_by_date <- function(bias_data,
173181
locs,
174182
nowcast_dates,
175-
date_range) {
183+
date_range,
184+
plot_name = "bias_over_time_25A",
185+
output_fp = file.path(
186+
"output", "figs",
187+
"zoom_25A", "supp"
188+
)) {
176189
# Calculate average bias across all nowcast dates for reference lines
177190
bias_avg <- filter(
178191
bias_data,
@@ -207,7 +220,8 @@ get_plot_bias_by_date <- function(bias_data,
207220
x = nowcast_date, y = bias,
208221
color = model
209222
)) +
210-
facet_wrap(~location) +
223+
facet_wrap(~location, ncol = 3) +
224+
coord_cartesian(ylim = c(-1, 1)) +
211225
get_plot_theme(dates = TRUE) +
212226
scale_color_manual(
213227
name = "Model",
@@ -221,6 +235,82 @@ get_plot_bias_by_date <- function(bias_data,
221235
date_breaks = "1 week",
222236
date_labels = "%d %b %Y"
223237
)
238+
dir_create(output_fp, recurse = TRUE)
239+
ggsave(file.path(output_fp, glue::glue("{plot_name}.png")),
240+
plot = p,
241+
width = 8,
242+
height = 6
243+
)
244+
245+
return(p)
246+
}
247+
248+
#' Get a plot of prediction interval coverage summarized across nowcast dates
249+
#'
250+
#' @param coverage Data.frame of coverage scores with interval_range
251+
#' @param locs Vector of character strings of locations (currently unused)
252+
#'
253+
#' @returns ggplot
254+
#' @autoglobal
255+
get_plot_coverage_overall <- function(coverage,
256+
locs) {
257+
# Average coverage within each model, location, and interval range
258+
coverage_summary <- coverage |>
259+
group_by(model_id, location, interval_range) |>
260+
summarise(empirical_coverage = sum(interval_coverage) / n()) |>
261+
pivot_wider(
262+
names_from = interval_range,
263+
values_from = empirical_coverage
264+
) |>
265+
mutate(`95` = `95` - `50`) |>
266+
pivot_longer(
267+
cols = c(`50`, `95`),
268+
names_to = "interval_range",
269+
values_to = "empirical_coverage"
270+
) |>
271+
mutate(
272+
interval_label = paste0(interval_range, "%"),
273+
interval_label = factor(interval_label, levels = c("95%", "50%"))
274+
)
275+
276+
277+
plot_comps <- plot_components()
278+
279+
p <- ggplot(coverage_summary) +
280+
# Add horizontal reference lines for nominal coverage
281+
# Create stacked bar chart
282+
geom_bar(
283+
aes(
284+
x = model_id, y = empirical_coverage, fill = model_id,
285+
alpha = interval_label
286+
),
287+
stat = "identity",
288+
position = "stack",
289+
width = 0.7
290+
) +
291+
geom_hline(yintercept = 0.5, linetype = "dashed") +
292+
geom_hline(yintercept = 0.95, linetype = "dashed") +
293+
facet_wrap(~location, ncol = 3) +
294+
get_plot_theme(dates = FALSE) +
295+
theme(axis.text.x = element_blank()) +
296+
scale_fill_manual(
297+
name = "Model",
298+
values = plot_comps$model_colors
299+
) +
300+
scale_alpha_manual(
301+
name = "Interval coverage",
302+
values = plot_comps$pred_int_alpha
303+
) +
304+
guides(
305+
fill = guide_legend(
306+
title.position = "top",
307+
title.hjust = 0.5,
308+
nrow = 3
309+
)
310+
) +
311+
xlab("Model") +
312+
ylab("Empirical\ncoverage") +
313+
scale_y_continuous(limits = c(0, 1), breaks = seq(0, 1, 0.2))
224314

225315
return(p)
226316
}
@@ -229,15 +319,15 @@ get_plot_bias_by_date <- function(bias_data,
229319
#'
230320
#' @param grid Model predictions plot
231321
#' @param scores Energy scores plot
232-
#' @param bias Bias scores plot
322+
#' @param coverage Prediction interval coverage plot
233323
#' @param plot_name name of plot
234324
#' @param output_fp filepath directory
235325
#'
236326
#' @returns patchwork
237327
#' @autoglobal
238328
get_fig_zoom_25A <- function(grid,
239329
scores,
240-
bias,
330+
coverage,
241331
plot_name,
242332
output_fp = file.path(
243333
"output", "figs",
@@ -253,10 +343,10 @@ get_fig_zoom_25A <- function(grid,
253343

254344
fig_zoom <- grid +
255345
scores +
256-
bias +
346+
coverage +
257347
plot_layout(
258348
design = fig_layout,
259-
axes = "collect",
349+
axes = "collect_x",
260350
guides = "collect"
261351
) +
262352
plot_annotation(

R/globals.R

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,19 @@ utils::globalVariables(c(
4545
"clades_modeled", # <get_clean_variant_data_ns>
4646
"location_code", # <get_clean_variant_data_ns>
4747
"sequences", # <get_clean_variant_data_ns>
48-
"target_date", # <prepare_data_for_scoring_25A>
49-
"nowcast_date", # <prepare_data_for_scoring_25A>
50-
"horizon", # <prepare_data_for_scoring_25A>
51-
"quantile_level", # <prepare_data_for_scoring_25A>
52-
"sequences", # <prepare_data_for_scoring_25A>
53-
"n_seq", # <prepare_data_for_scoring_25A>
54-
"location", # <compute_bias_25A>
55-
"nowcast_date", # <compute_bias_25A>
56-
"model_id", # <compute_bias_25A>
57-
"bias", # <compute_bias_25A>
48+
"target_date", # <prepare_data_for_scoring>
49+
"nowcast_date", # <prepare_data_for_scoring>
50+
"horizon", # <prepare_data_for_scoring>
51+
"quantile_level", # <prepare_data_for_scoring>
52+
"sequences", # <prepare_data_for_scoring>
53+
"n_seq", # <prepare_data_for_scoring>
54+
"location", # <compute_bias>
55+
"nowcast_date", # <compute_bias>
56+
"model_id", # <compute_bias>
57+
"bias", # <compute_bias>
58+
"location", # <compute_coverage>
59+
"nowcast_date", # <compute_coverage>
60+
"interval_range", # <compute_coverage>
5861
"location", # <extract_nowcasts>
5962
"nowcast_date", # <extract_nowcasts>
6063
"location", # <get_oracle_output>
@@ -135,5 +138,13 @@ utils::globalVariables(c(
135138
"model", # <get_plot_bias_by_date>
136139
"bias", # <get_plot_bias_by_date>
137140
"avg_bias", # <get_plot_bias_by_date>
141+
"model_id", # <get_plot_coverage_overall>
142+
"location", # <get_plot_coverage_overall>
143+
"interval_range", # <get_plot_coverage_overall>
144+
"interval_coverage", # <get_plot_coverage_overall>
145+
"empirical_coverage", # <get_plot_coverage_overall>
146+
"95", # <get_plot_coverage_overall>
147+
"50", # <get_plot_coverage_overall>
148+
"interval_label", # <get_plot_coverage_overall>
138149
NULL
139150
))

0 commit comments

Comments
 (0)