@@ -320,3 +320,333 @@ get_scatterplot_scores <- function(scores) {
320320 geom_line(aes(x = hosp_only , y = hosp_only ), linetype = " dashed" )
321321 return (p )
322322}
323+
324+ # ' Get bar chart of WIS by location and forecast date
325+ # '
326+ # ' @param scores Data.frame of scores from across locations and forecast dates
327+ # ' @param n_locations Integer indicating number of locations to plot. Default is
328+ # ' 3. If NULL, all locations are plotted.
329+ # '
330+ # ' @returns ggplot object
331+ # ' @importFrom scoringutils summarise_scores
332+ # ' @importFrom ggplot2 ggplot aes geom_bar theme labs
333+ # ' @importFrom dplyr filter arrange slice_head
334+ # ' @export
335+ # ' @autoglobal
336+ get_bar_chart_scores_by_loc <- function (scores , n_locations = 3 ) {
337+ # Aggregate scores by location and forecast date
338+ scores_by_loc <- scores | >
339+ group_by(model , include_ww , hosp_data_real_time , forecast_date , location ) | >
340+ summarise(wis = mean(wis , na.rm = TRUE ), .groups = " drop" ) | >
341+ mutate(model_ww = glue :: glue(" {model}-{include_ww}-{hosp_data_real_time}" ))
342+
343+ # Select locations to plot
344+ if (! is.null(n_locations )) {
345+ # Get top n_locations by average WIS
346+ top_locations <- scores_by_loc | >
347+ group_by(location ) | >
348+ summarise(mean_wis = mean(wis , na.rm = TRUE )) | >
349+ arrange(mean_wis ) | >
350+ slice_head(n = n_locations ) | >
351+ pull(location )
352+
353+ scores_by_loc <- filter(scores_by_loc , location %in% top_locations )
354+ }
355+
356+ p <- ggplot(scores_by_loc ) +
357+ geom_bar(
358+ aes(
359+ x = forecast_date ,
360+ y = wis ,
361+ fill = model_ww
362+ ),
363+ stat = " identity" ,
364+ position = " dodge"
365+ ) +
366+ facet_wrap(~ location , scales = " free_y" ) +
367+ theme_bw() +
368+ theme(
369+ axis.text.x = element_text(angle = 45 , hjust = 1 ),
370+ legend.position = " bottom"
371+ ) +
372+ labs(
373+ x = " Forecast Date" ,
374+ y = " WIS" ,
375+ fill = " Model" ,
376+ title = " WIS by Location and Forecast Date"
377+ )
378+
379+ return (p )
380+ }
381+
382+ # ' Create combined forecast and WIS plot by location
383+ # '
384+ # ' Creates a two-row plot for each location: top row shows forecast time series
385+ # ' (hospital admissions with/without wastewater), bottom row shows WIS over time
386+ # ' (including ARIMA baseline). Locations are arranged in columns.
387+ # '
388+ # ' The function filters to every other forecast date for readability. It only
389+ # ' includes WIS for forecast dates that have corresponding forecast data files,
390+ # ' ensuring the time axes align correctly between the forecast and WIS plots.
391+ # '
392+ # ' @param output_path Path to the output folder containing forecast data
393+ # ' @param forecast_dates Character vector of forecast dates
394+ # ' @param scores Data.frame of scores from across locations and forecast dates
395+ # ' @param locations Character vector of location names. If NULL, three locations
396+ # ' are selected. Default is NULL.
397+ # ' @param forecast_horizon_to_plot Integer indicating number of days of horizon
398+ # ' to plot. Default is 28.
399+ # ' @param historical_data_to_plot Integer indicating number of days into the
400+ # ' past to plot. Default is 90.
401+ # ' @param scale_selected Character string indicating which scale to plot,
402+ # ' default is "natural"
403+ # ' @param save_path Optional path to save the figure. If NULL, figure is not
404+ # ' saved. Default is NULL.
405+ # ' @param n_forecast_dates Integer indicating number of forecast dates to show
406+ # ' in the WIS bar charts. Dates are selected spread across the time range.
407+ # ' Default is 3.
408+ # '
409+ # ' @returns A combined patchwork plot
410+ # ' @importFrom scoringutils summarise_scores
411+ # ' @importFrom ggplot2 ggplot aes geom_line geom_ribbon geom_point geom_bar
412+ # ' theme_bw theme element_text labs scale_color_manual scale_fill_manual
413+ # ' @importFrom dplyr filter mutate bind_rows group_by summarise arrange
414+ # ' @importFrom tidyr pivot_wider
415+ # ' @importFrom lubridate ymd
416+ # ' @importFrom patchwork wrap_plots plot_layout
417+ # ' @export
418+ # ' @autoglobal
419+ get_combined_forecast_wis_plot <- function (
420+ output_path ,
421+ forecast_dates ,
422+ scores ,
423+ locations = NULL ,
424+ forecast_horizon_to_plot = 28 ,
425+ historical_data_to_plot = 90 ,
426+ scale_selected = " natural" ,
427+ save_path = NULL ,
428+ n_forecast_dates = 3 ) {
429+ # Get available forecast dates from the directory that have actual data
430+
431+ forecasts_dir <- file.path(output_path , " individual_forecasts_all_runs" )
432+ available_forecast_dates <- list.dirs(
433+ forecasts_dir ,
434+ full.names = FALSE ,
435+ recursive = FALSE
436+ )
437+
438+ # Filter to dates that have location subdirectories with actual forecast data
439+ dates_with_data <- sapply(available_forecast_dates , function (d ) {
440+ date_path <- file.path(forecasts_dir , d )
441+ subdirs <- list.dirs(date_path , full.names = FALSE , recursive = FALSE )
442+ # Check if there are actual location subdirectories (German state names)
443+ # by looking for subdirs that don't match date patterns or error files
444+ return (any(grepl(" ^[A-Z]" , subdirs ) & ! grepl(" Error|^[0-9]{4}" , subdirs )))
445+ })
446+ available_forecast_dates <- available_forecast_dates [dates_with_data ]
447+
448+ # Filter to dates that exist in both the input and directory
449+ forecast_dates_available <-
450+ forecast_dates [forecast_dates %in% available_forecast_dates ]
451+
452+ if (length(forecast_dates_available ) == 0 ) {
453+ stop(" No matching forecast dates found in directory" , call. = FALSE )
454+ }
455+
456+ # Select n_forecast_dates spread across the time range for readability
457+ if (length(forecast_dates_available ) > n_forecast_dates ) {
458+ indices <- round(seq(1 , length(forecast_dates_available ),
459+ length.out = n_forecast_dates
460+ ))
461+ forecast_dates_filtered <- forecast_dates_available [indices ]
462+ } else {
463+ forecast_dates_filtered <- forecast_dates_available
464+ }
465+
466+ # Determine locations first if not specified
467+ if (is.null(locations )) {
468+ # Get available locations from the first forecast date
469+ first_forecast_path <- file.path(
470+ output_path ,
471+ " individual_forecasts_all_runs" ,
472+ forecast_dates_filtered [1 ]
473+ )
474+ if (! dir.exists(first_forecast_path )) {
475+ stop(" Forecast directory not found" , call. = FALSE )
476+ }
477+ available_locations <- list.dirs(
478+ first_forecast_path ,
479+ full.names = FALSE ,
480+ recursive = FALSE
481+ )
482+ locations <- sample(
483+ available_locations ,
484+ size = min(3 , length(available_locations ))
485+ )
486+ }
487+
488+ # Load hospital forecasts using helper function
489+ hosp_forecasts_list <- load_hospital_forecasts(
490+ output_path , forecast_dates_filtered , locations
491+ )
492+
493+ if (length(hosp_forecasts_list ) == 0 ) {
494+ stop(" No hospital forecast data found" , call. = FALSE )
495+ }
496+
497+ hosp_forecasts <- bind_rows(hosp_forecasts_list )
498+
499+ # Process hospital data using helper function
500+ hosp_processed <- process_hospital_data(
501+ hosp_forecasts ,
502+ forecast_horizon_to_plot ,
503+ historical_data_to_plot ,
504+ scale_selected
505+ )
506+
507+ forecasts_wide <- hosp_processed $ forecasts
508+ hosp_obs <- hosp_processed $ observations
509+
510+ # Filter to selected locations and add model labels
511+ forecasts_wide <- forecasts_wide | >
512+ filter(location %in% locations ) | >
513+ mutate(
514+ model_label = case_when(
515+ model_ww == " wwinference-TRUE" ~ " With wastewater data" ,
516+ model_ww == " wwinference-FALSE" ~ " Without wastewater data" ,
517+ TRUE ~ model_ww
518+ )
519+ )
520+ hosp_obs <- filter(hosp_obs , location %in% locations )
521+
522+ # Process scores - filter to locations and forecast dates with data
523+ scores_filtered <- scores | >
524+ filter(
525+ location %in% locations ,
526+ forecast_date %in% forecast_dates_filtered
527+ ) | >
528+ group_by(model , include_ww , hosp_data_real_time , forecast_date , location ) | >
529+ summarise(wis = mean(wis , na.rm = TRUE ), .groups = " drop" ) | >
530+ mutate(
531+ model_label = case_when(
532+ model == " arima_baseline" ~ " ARIMA baseline" ,
533+ model == " wwinference" & include_ww ~ " With wastewater data" ,
534+ model == " wwinference" & ! include_ww ~ " Without wastewater data" ,
535+ TRUE ~ glue :: glue(" {model}-{include_ww}" )
536+ ),
537+ forecast_date = ymd(forecast_date )
538+ )
539+
540+ # Define color palette matching the original plots
541+ model_colors <- c(
542+ " ARIMA baseline" = " #E57373" ,
543+ " With wastewater data" = " #64B5F6" ,
544+ " Without wastewater data" = " #81C784"
545+ )
546+
547+ # Create plots for each location
548+ plot_list <- list ()
549+
550+ for (loc in locations ) {
551+ # Forecast plot for this location
552+ loc_forecasts <- filter(forecasts_wide , location == loc )
553+ loc_obs <- filter(hosp_obs , location == loc )
554+
555+ p_forecast <- ggplot() +
556+ geom_line(
557+ data = loc_forecasts ,
558+ aes(
559+ x = date_parsed ,
560+ y = q_0.5 ,
561+ group = forecast_date_model_ww ,
562+ color = model_label
563+ )
564+ ) +
565+ geom_ribbon(
566+ data = loc_forecasts ,
567+ aes(
568+ x = date_parsed ,
569+ ymin = q_0.25 ,
570+ ymax = q_0.75 ,
571+ group = forecast_date_model_ww ,
572+ fill = model_label
573+ ),
574+ alpha = 0.3
575+ ) +
576+ geom_point(
577+ data = loc_obs ,
578+ aes(x = date_parsed , y = observed ),
579+ color = " black"
580+ ) +
581+ scale_color_manual(values = model_colors , guide = " none" ) +
582+ scale_fill_manual(values = model_colors , guide = " none" ) +
583+ theme_bw() +
584+ labs(
585+ y = " 7-day hospital admissions" ,
586+ title = loc
587+ ) +
588+ theme(
589+ axis.title.x = element_blank()
590+ )
591+
592+ # WIS plot for this location
593+ loc_scores <- filter(scores_filtered , location == loc )
594+
595+ p_wis <- ggplot(loc_scores ) +
596+ geom_bar(
597+ aes(
598+ x = forecast_date ,
599+ y = wis ,
600+ fill = model_label
601+ ),
602+ stat = " identity" ,
603+ position = " dodge"
604+ ) +
605+ scale_fill_manual(values = model_colors ) +
606+ theme_bw() +
607+ labs(
608+ x = " Forecast Date" ,
609+ y = " WIS" ,
610+ fill = " Model"
611+ ) +
612+ theme(
613+ axis.text.x = element_text(angle = 45 , hjust = 1 ),
614+ legend.position = " bottom"
615+ )
616+
617+ # Combine forecast and WIS plots vertically
618+ combined_loc <- wrap_plots(
619+ p_forecast ,
620+ p_wis ,
621+ ncol = 1 ,
622+ heights = c(2 , 1 )
623+ )
624+
625+ plot_list [[loc ]] <- combined_loc
626+ }
627+
628+ # Combine all location plots horizontally
629+ p_combined <- wrap_plots(
630+ plot_list ,
631+ ncol = length(locations ),
632+ guides = " collect"
633+ ) &
634+ theme(legend.position = " bottom" )
635+
636+ # Save if path provided
637+ if (! is.null(save_path )) {
638+ dir.create(save_path , recursive = TRUE , showWarnings = FALSE )
639+ date_range <- glue :: glue(" {min(forecast_dates)}_to_{max(forecast_dates)}" )
640+ ggsave(
641+ filename = file.path(
642+ save_path ,
643+ glue :: glue(" combined_forecast_wis_{date_range}.png" )
644+ ),
645+ plot = p_combined ,
646+ width = 4 * length(locations ),
647+ height = 10
648+ )
649+ }
650+
651+ return (p_combined )
652+ }
0 commit comments