Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions R/EDA_plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,13 @@ get_bar_chart_overall_scores <- function(scores) {
#' @returns ggplot object
#' @autoglobal
get_plot_scores_by_date <- function(scores) {
scores_by_loc <- scores |>
scores_by_date <- scores |>
summarise_scores(by = c(
"model", "include_ww",
"hosp_data_real_time", "forecast_date"
)) |>
mutate(model_ww = glue::glue("{model}-{include_ww}-{hosp_data_real_time}"))
p <- ggplot(scores_by_loc) +
p <- ggplot(scores_by_date) +
geom_bar(
aes(
x = forecast_date,
Expand All @@ -287,3 +287,36 @@ get_plot_scores_by_date <- function(scores) {
ggtitle("Scores across all locations by forecast dates")
return(p)
}

#' Get scatterplot of scores by forecast date and location
#'
#' Summarises the scores by model, wastewater inclusion, real-time hospital
#' data flag, forecast date and location, keeps only the `wwinference` model,
#' and plots the WIS obtained with wastewater against the WIS obtained
#' without it. A dashed identity line (y = x) is drawn for reference.
#'
#' @param scores Data.frame of scores from across locations and forecast dates
#'
#' @importFrom ggplot2 ggplot aes geom_point geom_line
#' @importFrom scoringutils summarise_scores
#' @importFrom dplyr filter rename
#' @importFrom tidyr pivot_wider
#' @returns ggplot object
#' @autoglobal
get_scatterplot_scores <- function(scores) {
  # NOTE(review): `hosp_data_real_time` is part of the summary grouping but
  # not of `id_cols`; if it takes more than one value per forecast date and
  # location, pivot_wider() will return list-columns -- confirm upstream.
  scores_by_forecast <- scores |>
    summarise_scores(by = c(
      "model", "include_ww",
      "hosp_data_real_time", "forecast_date",
      "location"
    )) |>
    filter(model == "wwinference") |>
    pivot_wider(
      names_from = include_ww,
      values_from = wis,
      id_cols = c(forecast_date, location)
    ) |>
    # pivot_wider() names the new columns after the values of `include_ww`,
    # i.e. the strings "TRUE" and "FALSE". They must be selected by name:
    # bare logicals make rename() fail with
    # "Can't rename columns with `TRUE`".
    rename(
      ww_plus_hosp = "TRUE",
      hosp_only = "FALSE"
    )

  p <- ggplot(scores_by_forecast) +
    geom_point(aes(x = hosp_only, y = ww_plus_hosp)) +
    # Identity line: points below it have a lower WIS with wastewater
    geom_line(aes(x = hosp_only, y = hosp_only), linetype = "dashed")
  return(p)
}
18 changes: 9 additions & 9 deletions R/convert_to_su_object.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#' Convert the scores to a scoringutils object
#'
#' @param scores_data Data.frame from Variant Nowcast Hub GitHub
#' @param scores_raw Data.frame of scores
#' @importFrom data.table setattr as.data.table
#' @importFrom rlang arg_match
#' @importFrom dplyr rename select
#' @returns scoringutils object
convert_to_su_object <- function(scores_data) {
scores2 <- data.table::as.data.table(scores_data)
class(scores2) <- c("scores", class(scores2))
#' @autoglobal
convert_to_su_object <- function(scores_raw) {
scores <- data.table::as.data.table(scores_raw)
class(scores) <- c("scores", class(scores))
scores_su <- data.table::setattr(
scores2,
scores,
"metrics",
c(
"wis", "underprediction", "overprediction", "dispersion",
"bias", "interval_coverage_50", "interval_coverage_90",
"ae_median"
"wis", "overprediction", "underprediction",
"dispersion", "bias", "interval_coverage_50",
"interval_coverage_90", "ae_median"
)
)
return(scores_su)
Expand Down
7 changes: 7 additions & 0 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ utils::globalVariables(c(
"forecast_date", # <get_plot_scores_by_date>
"wis", # <get_plot_scores_by_date>
"model_ww", # <get_plot_scores_by_date>
"model", # <get_scatterplot_scores>
"include_ww", # <get_scatterplot_scores>
"wis", # <get_scatterplot_scores>
"forecast_date", # <get_scatterplot_scores>
"location", # <get_scatterplot_scores>
"hosp_only", # <get_scatterplot_scores>
"ww_plus_hosp", # <get_scatterplot_scores>
"site", # <fit_wwinference_wrapper>
"lab", # <fit_wwinference_wrapper>
"log_genome_copies_per_ml", # <fit_wwinference_wrapper>
Expand Down
9 changes: 9 additions & 0 deletions R/prep_scores_to_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#' Prepare scores for the model-based evaluation (placeholder)
#'
#' Not implemented yet: the function currently ignores both arguments and
#' returns `NULL`. The planned steps are listed as TODOs in the body.
#'
#' @param scores_long Presumably the scores in long format (currently unused)
#' @param ww_metadata Presumably the wastewater metadata to join onto the
#'   scores (currently unused)
#' @returns `NULL` until the function is implemented
prep_scores_to_model <- function(scores_long, ww_metadata) {
  # TODO: pivot scores from long to wide
  # TODO: join wastewater metadata
  NULL
}
Comment on lines +1 to +9
Copy link

@coderabbitai coderabbitai bot Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Implement the placeholder function before merging.

This function is called by scores_to_model in targets/run_gam_targets.R but currently returns NULL, which will break the downstream pipeline. The PR description mentions remaining tasks, but this function should be implemented or the calling target should be removed before merging.

Would you like me to generate a basic implementation template for pivoting scores and joining metadata, or should this target be temporarily removed from the pipeline?

🤖 Prompt for AI Agents
In @R/prep_scores_to_model.R around lines 1 - 9, prep_scores_to_model currently
returns NULL and breaks downstream targets; implement it to pivot scores_long
from long to wide (e.g., each sample/feature combination becomes a row with
score columns) and then join ww_metadata by the appropriate key (sample id or
site id used in scores_long) before returning a cleaned data.frame/tibble;
update the function prep_scores_to_model to accept scores_long and ww_metadata,
use pivot_wider (or spread) on the score name column and values column, perform
a left_join with ww_metadata using the matching id column present in both,
ensure column names are consistent with the caller scores_to_model in
targets/run_gam_targets.R, and return the resulting joined table instead of
NULL.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to defer this to a subsequent PR.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaitejohnson, understood! Would you like me to open a follow-up issue to track the implementation of prep_scores_to_model?

71 changes: 36 additions & 35 deletions _targets.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Targets script for generating forecasts and performing immediate
# post-processing (quantiling and scoring)
# Targets script for analysing forecasts and scores
# This pipeline assumes that the `output/` folder contains:
# - overall_data_all_runs/scores.csv: a single file with all of the scores
# for all forecasts for the 3 models (wwinference with and without ww and
# baseline ARIMA)
# - individual_forecasts_all_runs/{forecast_date}/{location}/data: hospital
# admissions quantiles for wwinference with and without wastewater, R(t)
# estimates for the location with and without wastewater, and predicted
# quantiled wastewater concentrations


# The pipeline can be run using `tar_make()`

Expand Down Expand Up @@ -28,9 +36,6 @@ functions <- list.files(here("R"), full.names = TRUE)
walk(functions, source)
rm("functions")

n_workers <- as.integer(floor(future::availableCores() / 4))
plan(multisession, workers = n_workers)

# load target modules
targets <- list.files(here("targets"), full.names = TRUE)
targets <- grep("*\\.R", targets, value = TRUE)
Expand Down Expand Up @@ -59,45 +64,41 @@ tar_option_set(
error = "continue"
)

## Set up the date:location:model:ww+/-:right-trunc+/- permutations
set_up <- list(
create_permutations_targets
# Analysis config
analysis_config <- list(
# Full set of dates, locations and models for which the model
# was run
create_permutations_targets,
# Set of dates and locations to focus on in example figures +
# specifications of any post-processing model outputs
analysis_config_targets
)


## Iterate over all permutations. For each:
# - extract the necessary data
# - pre-process the data based on the model's requirements
# - fit the model
# - extract posterior hospital admissions (calibration and forecast)
# - score the forecasts using CRPS and extract
# - quantile the calibration and forecasted admissions and extract
# - extract input data (hosp and/or ww)
# - extract model diagnostics

# Current set up: uses the `scenarios` tibble to do dynamic branching within
# each function via pattern = map(ind_data_created, scenarios)
load_data <- list(
# Load data for each location/forecast date combination
load_data_targets,
load_baseline_data_targets
)
# Wastewater metadata
get_metadata <- list(
get_metadata_targets
)
fit_models <- list(
fit_model_targets,
fit_baseline_model_targets

# Secondary outputs
secondary_outputs <- list(
# GAM meta-model on scores ()
run_gam_targets
# compute coverage metrics (?)
)

scoring <- list(
scoring_targets
# Figures
plot_targets <- list(
analysis_EDA_plot_targets
# Fig 1: visual comparison for a single forecast date
# Fig 2: visual comparison + scores across forecast dates
# Fig 3: overall, by horizon, by location, by forecast date
# by location and forecast date
# Fig 4: Model-based evaluation results
)

list(
set_up,
load_data,
analysis_config,
get_metadata,
fit_models,
scoring
secondary_outputs,
plot_targets
)
100 changes: 100 additions & 0 deletions _targets_model_run.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Targets script for generating forecasts and performing immediate
# post-processing (quantiling and scoring)

# The pipeline can be run using `tar_make(script = "_targets_model_run.R")`

library(targets)
library(jsonlite)
library(httr)
library(tarchetypes)
library(wwinference)
library(dplyr)
library(ggplot2)
library(readr)
library(here)
library(purrr)
library(lubridate)
library(tidyr)
library(glue)
library(fs)
library(rlang)
library(scoringutils)
library(forecast)
library(future)
library(future.callr)

# load functions: source every file in the project's R/ directory
functions <- list.files(here("R"), full.names = TRUE)
walk(functions, source)
rm("functions")

# Use a quarter of the available cores, but never fewer than one worker:
# on machines with fewer than 4 cores the floor() would otherwise be 0,
# which is not a valid number of workers for a multisession plan.
n_workers <- max(1L, as.integer(floor(future::availableCores() / 4)))
plan(multisession, workers = n_workers)

# load target modules: only files whose name ends in ".R" (the anchored
# pattern excludes e.g. .Rmd, .RData or .Rproj files in the same folder)
targets <- list.files(here("targets"), pattern = "\\.R$", full.names = TRUE)
purrr::walk(targets, source)

tar_option_set(
packages = c(
"wwinference",
"tibble",
"dplyr",
"ggplot2",
"readr",
"lubridate",
"tidyr",
"glue",
"forecast",
"jsonlite",
"httr"
),
workspace_on_error = TRUE,
storage = "worker",
retrieval = "worker",
memory = "transient",
garbage_collection = TRUE,
format = "parquet", # default storage format
error = "continue"
)

## Set up the date:location:model:ww+/-:right-trunc+/- permutations
set_up <- list(
create_permutations_targets
)


## Iterate over all permutations. For each:
# - extract the necessary data
# - pre-process the data based on the model's requirements
# - fit the model
# - extract posterior hospital admissions (calibration and forecast)
# - score the forecasts using CRPS and extract
# - quantile the calibration and forecasted admissions and extract
# - extract input data (hosp and/or ww)
# - extract model diagnostics

# Current set up: uses the `scenarios` tibble to do dynamic branching within
# each function via pattern = map(ind_data_created, scenarios)
load_data <- list(
# Load data for each location/forecast date combination
load_data_targets,
load_baseline_data_targets
)

fit_models <- list(
fit_model_targets,
fit_baseline_model_targets
)

scoring <- list(
scoring_targets
)

list(
set_up,
load_data,
fit_models,
scoring
)
4 changes: 2 additions & 2 deletions man/convert_to_su_object.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/get_scatterplot_scores.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions scratch/explore_results.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
scores <- read_csv(file.path("output", "overall_data", "scores.csv"))

scores_overall <- scores |>
group_by(model, include_ww) |>
summarise(wis = mean(wis)) |>
Copy link

@coderabbitai coderabbitai bot Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add na.rm = TRUE to the mean() call.

If the wis column contains missing values, the aggregation will return NA for all groups, causing downstream plotting to fail.

🔎 Proposed fix
-  summarise(wis = mean(wis)) |>
+  summarise(wis = mean(wis, na.rm = TRUE)) |>
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
summarise(wis = mean(wis)) |>
summarise(wis = mean(wis, na.rm = TRUE)) |>
🤖 Prompt for AI Agents
In @scratch/explore_results.R at line 5, The summarise call computes wis with
mean(wis) which will return NA if any wis values are missing; update the
aggregation in the summarise step (the summarise function that defines wis) to
call mean(wis, na.rm = TRUE) so missing values are ignored and downstream
plotting won't fail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also ignoring stuff in scratch

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaitejohnson, understood! Since scratch/ contains exploratory code, it makes sense to skip detailed reviews there.

If you'd like to prevent automated reviews on the scratch/ directory in future PRs, you can configure CodeRabbit to exclude it by adding a path_filters configuration in your .coderabbit.yaml file. For example:

reviews:
  path_filters:
    - "!scratch/**"

This will exclude all files under scratch/ from automated reviews.

ungroup() |>
mutate(model_ww = glue::glue("{model}-{include_ww}"))


ggplot(scores_overall) +
geom_bar(aes(x = model_ww, y = wis, fill = model_ww),
stat = "identity", position = "stack"
)

rwis <- scores_overall$wis / scores_overall$wis[2]

scores_by_date <- scores |>
group_by(model, include_ww, forecast_date) |>
summarise(wis = mean(wis)) |>
ungroup() |>
mutate(model_ww = glue::glue("{model}-{include_ww}"))

ggplot(scores_by_date) +
geom_line(aes(x = forecast_date, y = wis, color = model_ww))

scores_by_loc <- scores |>
group_by(model, include_ww, location) |>
summarise(wis = mean(wis)) |>
ungroup() |>
mutate(model_ww = glue::glue("{model}-{include_ww}"))

ggplot(scores_by_loc) +
geom_point(aes(x = location, y = wis, color = model_ww))
10 changes: 10 additions & 0 deletions targets/analysis_EDA_plot_targets.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Exploratory figure targets for the analysis pipeline.
# Both targets depend on `scores`, which is not defined in this file and is
# presumably an upstream target holding the forecast scores.
analysis_EDA_plot_targets <- list(
  # Scores across all locations by forecast date
  tar_target(plot_scores_by_date, get_plot_scores_by_date(scores)),
  # Scatterplot of scores by forecast date and location
  tar_target(scatterplot_scores, get_scatterplot_scores(scores))
)
17 changes: 17 additions & 0 deletions targets/analysis_config_targets.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Configuration targets for the analysis pipeline.
analysis_config_targets <- list(
  # Wastewater data as of each forecast date: one dynamic branch per row of
  # `scenarios`. `calibration_period_wwinference` and `path_to_lod_vals` are
  # not defined in this file (presumably upstream targets or globals).
  tar_target(
    name = ww_data_post,
    command = get_ww_as_of_forecast_date(
      forecast_date = scenarios$forecast_date,
      location_name = scenarios$location_name,
      location_abbr = scenarios$location_abbr,
      calibration_period = calibration_period_wwinference,
      path_to_lod_vals = path_to_lod_vals
    ),
    pattern = map(scenarios)
  ),
  # Path to the combined scores file. Declared as a file target so that the
  # pipeline tracks the contents of scores.csv (downstream targets rerun
  # when the file changes) and so that the path is not stored with a
  # data-frame-only default format. The file must exist when the target runs.
  tar_target(
    name = scores_fp,
    command = file.path("output", "overall_data_all_runs", "scores.csv"),
    format = "file"
  )
)
Loading
Loading