Skip to content

Commit e3fb560

Browse files
figure 4 first draft
1 parent 422dbd6 commit e3fb560

File tree

5 files changed

+433
-0
lines changed

5 files changed

+433
-0
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ export(calculate_ww_metadata_table)
44
export(convert_rolling_sum_to_inc)
55
export(create_hospital_plot)
66
export(create_ww_plot)
7+
export(get_bar_chart_scores_by_loc)
8+
export(get_combined_forecast_wis_plot)
79
export(load_hospital_forecasts)
810
export(load_later_ww_obs)
911
export(load_ww_forecasts)
@@ -35,6 +37,7 @@ importFrom(dplyr,n_distinct)
3537
importFrom(dplyr,pull)
3638
importFrom(dplyr,rename)
3739
importFrom(dplyr,select)
40+
importFrom(dplyr,slice_head)
3841
importFrom(dplyr,summarise)
3942
importFrom(dplyr,ungroup)
4043
importFrom(forecast,auto.arima)
@@ -57,6 +60,8 @@ importFrom(ggplot2,ggplot)
5760
importFrom(ggplot2,ggsave)
5861
importFrom(ggplot2,ggtitle)
5962
importFrom(ggplot2,labs)
63+
importFrom(ggplot2,scale_color_manual)
64+
importFrom(ggplot2,scale_fill_manual)
6065
importFrom(ggplot2,theme)
6166
importFrom(ggplot2,theme_bw)
6267
importFrom(ggplot2,vars)

R/EDA_plots.R

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,333 @@ get_scatterplot_scores <- function(scores) {
320320
geom_line(aes(x = hosp_only, y = hosp_only), linetype = "dashed")
321321
return(p)
322322
}
323+
324+
#' Get bar chart of WIS by location and forecast date
#'
#' Averages WIS within each combination of model, `include_ww`,
#' `hosp_data_real_time`, forecast date and location, then draws dodged bars of
#' the mean WIS by forecast date with one facet per location.
#'
#' @param scores Data.frame of scores from across locations and forecast dates.
#'   Must contain the columns `model`, `include_ww`, `hosp_data_real_time`,
#'   `forecast_date`, `location` and `wis`.
#' @param n_locations Integer indicating number of locations to plot. Default is
#'   3. The locations with the lowest mean WIS (averaged over models and
#'   forecast dates) are kept. If NULL, all locations are plotted.
#'
#' @returns ggplot object
#' @importFrom scoringutils summarise_scores
#' @importFrom ggplot2 ggplot aes geom_bar theme labs facet_wrap theme_bw
#'   element_text
#' @importFrom dplyr filter arrange slice_head group_by summarise mutate pull
#' @export
#' @autoglobal
get_bar_chart_scores_by_loc <- function(scores, n_locations = 3) {
  # Fail early with a readable message rather than a data-masking error raised
  # from inside the pipeline below.
  required_cols <- c(
    "model", "include_ww", "hosp_data_real_time",
    "forecast_date", "location", "wis"
  )
  missing_cols <- setdiff(required_cols, names(scores))
  if (length(missing_cols) > 0) {
    stop(
      "`scores` is missing required column(s): ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }
  if (!is.null(n_locations) &&
    (!is.numeric(n_locations) || length(n_locations) != 1L ||
      is.na(n_locations))) {
    stop("`n_locations` must be NULL or a single number", call. = FALSE)
  }

  # Aggregate scores by location and forecast date
  scores_by_loc <- scores |>
    group_by(
      model, include_ww, hosp_data_real_time, forecast_date, location
    ) |>
    summarise(wis = mean(wis, na.rm = TRUE), .groups = "drop") |>
    mutate(
      # One label per model configuration; used as the bar fill
      model_ww = glue::glue("{model}-{include_ww}-{hosp_data_real_time}")
    )

  # Select locations to plot
  if (!is.null(n_locations)) {
    # Ascending sort + slice_head keeps the n_locations locations with the
    # LOWEST average WIS
    top_locations <- scores_by_loc |>
      group_by(location) |>
      summarise(mean_wis = mean(wis, na.rm = TRUE)) |>
      arrange(mean_wis) |>
      slice_head(n = n_locations) |>
      pull(location)

    scores_by_loc <- filter(scores_by_loc, location %in% top_locations)
  }

  p <- ggplot(scores_by_loc) +
    geom_bar(
      aes(
        x = forecast_date,
        y = wis,
        fill = model_ww
      ),
      stat = "identity",
      position = "dodge"
    ) +
    # Free y-axis: each location gets its own WIS range
    facet_wrap(~location, scales = "free_y") +
    theme_bw() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = "bottom"
    ) +
    labs(
      x = "Forecast Date",
      y = "WIS",
      fill = "Model",
      title = "WIS by Location and Forecast Date"
    )

  return(p)
}
381+
382+
#' Create combined forecast and WIS plot by location
#'
#' Creates a two-row panel for each location: the top row shows the forecast
#' time series (hospital admissions with/without wastewater) together with the
#' observations, the bottom row shows WIS by forecast date as dodged bars
#' (including the ARIMA baseline). Locations are arranged in columns.
#'
#' Only forecast dates that are both requested in `forecast_dates` and present
#' as a directory with at least one location subdirectory under
#' `<output_path>/individual_forecasts_all_runs` are used. If more than
#' `n_forecast_dates` such dates remain, `n_forecast_dates` of them, evenly
#' spaced by position, are kept for readability. The same dates are used for
#' the forecast and the WIS panels.
#'
#' The function stops with an error if none of the requested forecast dates
#' has forecast data on disk, or if no hospital forecasts could be loaded.
#'
#' @param output_path Path to the output folder containing forecast data
#' @param forecast_dates Character vector of forecast dates, written the same
#'   way as the date directory names
#' @param scores Data.frame of scores from across locations and forecast dates.
#'   Must contain the columns `model`, `include_ww`, `hosp_data_real_time`,
#'   `forecast_date`, `location` and `wis`.
#' @param locations Character vector of location names. If NULL, up to three
#'   locations are drawn at random (via `sample()`) from the location
#'   directories of the first selected forecast date; call `set.seed()` first
#'   for a reproducible choice. Default is NULL.
#' @param forecast_horizon_to_plot Integer indicating number of days of horizon
#'   to plot. Default is 28.
#' @param historical_data_to_plot Integer indicating number of days into the
#'   past to plot. Default is 90.
#' @param scale_selected Character string indicating which scale to plot,
#'   default is "natural"
#' @param save_path Optional path to a folder in which to save the figure as a
#'   PNG. The folder is created if needed. If NULL, figure is not saved.
#'   Default is NULL.
#' @param n_forecast_dates Integer indicating the maximum number of forecast
#'   dates to show, in both the forecast panels and the WIS bar charts. Dates
#'   are selected spread across the available dates. Default is 3.
#'
#' @returns A combined patchwork plot
#' @importFrom scoringutils summarise_scores
#' @importFrom ggplot2 ggplot aes geom_line geom_ribbon geom_point geom_bar
#'   theme_bw theme element_text element_blank labs scale_color_manual
#'   scale_fill_manual ggsave
#' @importFrom dplyr filter mutate bind_rows group_by summarise arrange
#'   case_when
#' @importFrom tidyr pivot_wider
#' @importFrom lubridate ymd
#' @importFrom patchwork wrap_plots plot_layout
#' @export
#' @autoglobal
get_combined_forecast_wis_plot <- function(
    output_path,
    forecast_dates,
    scores,
    locations = NULL,
    forecast_horizon_to_plot = 28,
    historical_data_to_plot = 90,
    scale_selected = "natural",
    save_path = NULL,
    n_forecast_dates = 3) {
  if (!is.numeric(n_forecast_dates) || length(n_forecast_dates) != 1L ||
    is.na(n_forecast_dates) || n_forecast_dates <= 0) {
    stop(
      "`n_forecast_dates` must be a single positive number",
      call. = FALSE
    )
  }

  # Location subdirectories start with a capital letter (German state names);
  # error folders and date-like folders are not locations.
  is_location_dir <- function(dir_names) {
    grepl("^[A-Z]", dir_names) & !grepl("Error|^[0-9]{4}", dir_names)
  }

  # Get available forecast dates from the directory that have actual data
  forecasts_dir <- file.path(output_path, "individual_forecasts_all_runs")
  available_forecast_dates <- list.dirs(
    forecasts_dir,
    full.names = FALSE,
    recursive = FALSE
  )

  # Keep only dates holding at least one location subdirectory. vapply (not
  # sapply) guarantees a logical vector even when the directory is missing or
  # empty, so that case reaches the informative stop() below instead of
  # failing on an invalid list subscript.
  dates_with_data <- vapply(
    available_forecast_dates,
    function(d) {
      subdirs <- list.dirs(
        file.path(forecasts_dir, d),
        full.names = FALSE,
        recursive = FALSE
      )
      any(is_location_dir(subdirs))
    },
    logical(1)
  )
  available_forecast_dates <- available_forecast_dates[dates_with_data]

  # Filter to dates that exist in both the input and directory
  forecast_dates_available <-
    forecast_dates[forecast_dates %in% available_forecast_dates]

  if (length(forecast_dates_available) == 0) {
    stop("No matching forecast dates found in directory", call. = FALSE)
  }

  # Select n_forecast_dates spread across the time range for readability
  if (length(forecast_dates_available) > n_forecast_dates) {
    indices <- round(seq(1, length(forecast_dates_available),
      length.out = n_forecast_dates
    ))
    forecast_dates_filtered <- forecast_dates_available[indices]
  } else {
    forecast_dates_filtered <- forecast_dates_available
  }

  # Determine locations first if not specified
  if (is.null(locations)) {
    # Get available locations from the first forecast date
    first_forecast_path <- file.path(forecasts_dir, forecast_dates_filtered[1])
    if (!dir.exists(first_forecast_path)) {
      stop("Forecast directory not found", call. = FALSE)
    }
    available_locations <- list.dirs(
      first_forecast_path,
      full.names = FALSE,
      recursive = FALSE
    )
    # Apply the same rule used to detect dates with data, so an error or
    # date-like folder can never be drawn as a "location"
    available_locations <-
      available_locations[is_location_dir(available_locations)]
    # Random draw of up to three locations
    locations <- sample(
      available_locations,
      size = min(3, length(available_locations))
    )
  }
  # One panel column per distinct location; duplicates would be loaded twice
  # and would make `ncol` disagree with the number of panels
  locations <- unique(locations)

  # Load hospital forecasts using helper function
  hosp_forecasts_list <- load_hospital_forecasts(
    output_path, forecast_dates_filtered, locations
  )

  if (length(hosp_forecasts_list) == 0) {
    stop("No hospital forecast data found", call. = FALSE)
  }

  hosp_forecasts <- bind_rows(hosp_forecasts_list)

  # Process hospital data using helper function
  hosp_processed <- process_hospital_data(
    hosp_forecasts,
    forecast_horizon_to_plot,
    historical_data_to_plot,
    scale_selected
  )

  forecasts_wide <- hosp_processed$forecasts
  hosp_obs <- hosp_processed$observations

  # Filter to selected locations and add model labels
  forecasts_wide <- forecasts_wide |>
    filter(location %in% locations) |>
    mutate(
      model_label = case_when(
        model_ww == "wwinference-TRUE" ~ "With wastewater data",
        model_ww == "wwinference-FALSE" ~ "Without wastewater data",
        TRUE ~ model_ww
      )
    )
  hosp_obs <- filter(hosp_obs, location %in% locations)

  # Process scores - filter to locations and forecast dates with data.
  # Dates are compared as text so the match does not depend on whether
  # `scores$forecast_date` is stored as character or as Date.
  scores_filtered <- scores |>
    filter(
      location %in% locations,
      as.character(forecast_date) %in% as.character(forecast_dates_filtered)
    ) |>
    group_by(
      model, include_ww, hosp_data_real_time, forecast_date, location
    ) |>
    summarise(wis = mean(wis, na.rm = TRUE), .groups = "drop") |>
    mutate(
      model_label = case_when(
        model == "arima_baseline" ~ "ARIMA baseline",
        model == "wwinference" & include_ww ~ "With wastewater data",
        model == "wwinference" & !include_ww ~ "Without wastewater data",
        TRUE ~ glue::glue("{model}-{include_ww}")
      ),
      forecast_date = ymd(forecast_date)
    )

  # Define color palette matching the original plots
  model_colors <- c(
    "ARIMA baseline" = "#E57373",
    "With wastewater data" = "#64B5F6",
    "Without wastewater data" = "#81C784"
  )

  # Create plots for each location
  plot_list <- list()

  for (loc in locations) {
    # Forecast plot for this location
    loc_forecasts <- filter(forecasts_wide, location == loc)
    loc_obs <- filter(hosp_obs, location == loc)

    p_forecast <- ggplot() +
      geom_line(
        data = loc_forecasts,
        aes(
          x = date_parsed,
          y = q_0.5,
          group = forecast_date_model_ww,
          color = model_label
        )
      ) +
      geom_ribbon(
        data = loc_forecasts,
        aes(
          x = date_parsed,
          ymin = q_0.25,
          ymax = q_0.75,
          group = forecast_date_model_ww,
          fill = model_label
        ),
        alpha = 0.3
      ) +
      geom_point(
        data = loc_obs,
        aes(x = date_parsed, y = observed),
        color = "black"
      ) +
      # Legends are suppressed here; the WIS panel below carries the legend
      scale_color_manual(values = model_colors, guide = "none") +
      scale_fill_manual(values = model_colors, guide = "none") +
      theme_bw() +
      labs(
        y = "7-day hospital admissions",
        title = loc
      ) +
      theme(
        axis.title.x = element_blank()
      )

    # WIS plot for this location
    loc_scores <- filter(scores_filtered, location == loc)

    p_wis <- ggplot(loc_scores) +
      geom_bar(
        aes(
          x = forecast_date,
          y = wis,
          fill = model_label
        ),
        stat = "identity",
        position = "dodge"
      ) +
      scale_fill_manual(values = model_colors) +
      theme_bw() +
      labs(
        x = "Forecast Date",
        y = "WIS",
        fill = "Model"
      ) +
      theme(
        axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = "bottom"
      )

    # Combine forecast and WIS plots vertically
    combined_loc <- wrap_plots(
      p_forecast,
      p_wis,
      ncol = 1,
      heights = c(2, 1)
    )

    plot_list[[loc]] <- combined_loc
  }

  # Combine all location plots horizontally
  p_combined <- wrap_plots(
    plot_list,
    ncol = length(locations),
    guides = "collect"
  ) &
    theme(legend.position = "bottom")

  # Save if path provided
  if (!is.null(save_path)) {
    dir.create(save_path, recursive = TRUE, showWarnings = FALSE)
    # File name reflects the range of the REQUESTED dates, not only the
    # subset that ended up being plotted
    date_range <- glue::glue("{min(forecast_dates)}_to_{max(forecast_dates)}")
    ggsave(
      filename = file.path(
        save_path,
        glue::glue("combined_forecast_wis_{date_range}.png")
      ),
      plot = p_combined,
      width = 4 * length(locations),
      height = 10
    )
  }

  return(p_combined)
}

R/globals.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@ utils::globalVariables(c(
4141
"FALSE", # <get_scatterplot_scores>
4242
"hosp_only", # <get_scatterplot_scores>
4343
"ww_plus_hosp", # <get_scatterplot_scores>
44+
"model", # <get_bar_chart_scores_by_loc>
45+
"include_ww", # <get_bar_chart_scores_by_loc>
46+
"hosp_data_real_time", # <get_bar_chart_scores_by_loc>
47+
"forecast_date", # <get_bar_chart_scores_by_loc>
48+
"location", # <get_bar_chart_scores_by_loc>
49+
"wis", # <get_bar_chart_scores_by_loc>
50+
"mean_wis", # <get_bar_chart_scores_by_loc>
51+
"model_ww", # <get_bar_chart_scores_by_loc>
52+
"location", # <get_combined_forecast_wis_plot>
53+
"forecast_date", # <get_combined_forecast_wis_plot>
54+
"model", # <get_combined_forecast_wis_plot>
55+
"include_ww", # <get_combined_forecast_wis_plot>
56+
"hosp_data_real_time", # <get_combined_forecast_wis_plot>
57+
"wis", # <get_combined_forecast_wis_plot>
58+
"date_parsed", # <get_combined_forecast_wis_plot>
59+
"q_0.5", # <get_combined_forecast_wis_plot>
60+
"forecast_date_model_ww", # <get_combined_forecast_wis_plot>
61+
"model_label", # <get_combined_forecast_wis_plot>
62+
"q_0.25", # <get_combined_forecast_wis_plot>
63+
"q_0.75", # <get_combined_forecast_wis_plot>
64+
"observed", # <get_combined_forecast_wis_plot>
4465
"site", # <fit_wwinference_wrapper>
4566
"lab", # <fit_wwinference_wrapper>
4667
"log_genome_copies_per_ml", # <fit_wwinference_wrapper>

0 commit comments

Comments
 (0)