Skip to content

Commit 414386e

Browse files
authored
Merge pull request #88 from epiforecasts/81-add-option-for-real-time-hosp-data
Issue 81: Add option to fit to real-time hospital admissions data
2 parents 7dd8ef9 + 659d8f8 commit 414386e

28 files changed

+695
-120
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Depends:
2525
R (>= 4.0.0)
2626
Imports:
2727
cli,
28+
data.table,
2829
forecast,
2930
dplyr,
3031
tidyr,
@@ -34,7 +35,8 @@ Imports:
3435
wwinference,
3536
lubridate,
3637
fs,
37-
purrr,
38+
arrow,
39+
tidyselect,
3840
glue,
3941
scoringutils,
4042
zoo,

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ export(quantile_metrics)
55
export(quiet)
66
export(sample_metrics)
77
export(save_csv)
8+
export(trajectories_to_quantiles)
89
importFrom(cli,cli_warn)
10+
importFrom(data.table,as.data.table)
11+
importFrom(data.table,setattr)
912
importFrom(dplyr,arrange)
1013
importFrom(dplyr,case_when)
1114
importFrom(dplyr,desc)
@@ -23,6 +26,7 @@ importFrom(forecast,auto.arima)
2326
importFrom(forecast,forecast)
2427
importFrom(fs,dir_create)
2528
importFrom(ggplot2,aes)
29+
importFrom(ggplot2,element_text)
2630
importFrom(ggplot2,facet_wrap)
2731
importFrom(ggplot2,geom_bar)
2832
importFrom(ggplot2,geom_line)
@@ -32,6 +36,7 @@ importFrom(ggplot2,geom_vline)
3236
importFrom(ggplot2,ggplot)
3337
importFrom(ggplot2,ggsave)
3438
importFrom(ggplot2,ggtitle)
39+
importFrom(ggplot2,theme)
3540
importFrom(ggplot2,theme_bw)
3641
importFrom(ggplot2,xlab)
3742
importFrom(ggplot2,ylab)

R/EDA_plots.R

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ get_plot_model_comparison <- function(
152152
}
153153
full_fp <- file.path(fig_fp, this_location)
154154
if (!file.exists(full_fp)) {
155-
dir_create(full_fp, recursive = TRUE, showWarnings = FALSE)
155+
dir_create(full_fp, recurse = TRUE)
156156
}
157157
ggsave(
158158
plot = p,
@@ -169,14 +169,15 @@ get_plot_model_comparison <- function(
169169
#'
170170
#' @param draws_w_data Data.frame of draws with data
171171
#' @param full_fp Directory to save
172-
#' @importFrom ggplot2 geom_vline
172+
#' @importFrom ggplot2 geom_vline theme element_text
173173
#' @returns ggplot object
174174
#' @autoglobal
175175
get_plot_draws_w_calib_data <- function(draws_w_data,
176176
full_fp) {
177177
loc <- unique(draws_w_data$location)
178178
include_ww <- unique(draws_w_data$include_ww)
179179
forecast_date <- unique(draws_w_data$forecast_date)
180+
hosp_data_real_time <- unique(draws_w_data$hosp_data_real_time)
180181
n_draws <- max(draws_w_data$draw, na.rm = TRUE)
181182
draws <- draws_w_data |> dplyr::filter(
182183
draw %in% sample.int(n_draws, size = min(100, n_draws))
@@ -194,14 +195,15 @@ get_plot_draws_w_calib_data <- function(draws_w_data,
194195
) +
195196
xlab("") +
196197
theme_bw() +
198+
theme(plot.title = element_text(size = 10)) +
197199
geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
198200
ylab("7-day rolling sum of hospital admissions") +
199-
ggtitle(glue("location: {loc}, include_ww: {include_ww}, forecast_date: {forecast_date}")) # nolint
201+
ggtitle(glue("location: {loc}, include_ww: {include_ww}, forecast_date: {forecast_date}, hosp data real time: {hosp_data_real_time}")) # nolint
200202
ggsave(
201203
plot = p,
202204
filename = file.path(
203205
full_fp,
204-
glue::glue("7d_hosp_draws_w_data_ww_{include_ww}.png")
206+
glue::glue("7d_hosp_draws_w_data_ww_{include_ww}_rt_{hosp_data_real_time}.png") # nolint
205207
)
206208
)
207209
return(p)
@@ -217,8 +219,8 @@ get_plot_draws_w_calib_data <- function(draws_w_data,
217219
#' @autoglobal
218220
get_bar_chart_overall_scores <- function(scores) {
219221
scores_summarised <- scores |>
220-
summarise_scores(by = c("model", "include_ww")) |>
221-
mutate(model_ww = glue::glue("{model}-{include_ww}"))
222+
summarise_scores(by = c("model", "include_ww", "hosp_data_real_time")) |>
223+
mutate(model_ww = glue::glue("{model}-{include_ww}-{hosp_data_real_time}"))
222224

223225
p <- ggplot(scores_summarised) +
224226
geom_bar(
@@ -232,5 +234,56 @@ get_bar_chart_overall_scores <- function(scores) {
232234
) +
233235
theme_bw() +
234236
ggtitle("Scores across all locations and forecast dates")
237+
238+
scores_by_loc <- scores |>
239+
summarise_scores(by = c(
240+
"model", "include_ww",
241+
"hosp_data_real_time", "forecast_date"
242+
)) |>
243+
mutate(model_ww = glue::glue("{model}-{include_ww}-{hosp_data_real_time}"))
244+
p <- ggplot(scores_by_loc) +
245+
geom_bar(
246+
aes(
247+
x = forecast_date,
248+
y = wis,
249+
fill = model_ww
250+
),
251+
stat = "identity",
252+
position = "dodge"
253+
) +
254+
theme_bw() +
255+
theme(legend.position = "bottom") +
256+
ggtitle("Scores across all locations by forecast dates")
257+
return(p)
258+
}
259+
260+
#' Get bar chart of the scores by forecast date
261+
#'
262+
#' @param scores Data.frame of scores from across locations and forecast dates
263+
#'
264+
#' @importFrom ggplot2 geom_bar
265+
#' @importFrom scoringutils summarise_scores
266+
#' @returns ggplot object
267+
#' @autoglobal
268+
get_plot_scores_by_date <- function(scores) {
269+
scores_by_loc <- scores |>
270+
summarise_scores(by = c(
271+
"model", "include_ww",
272+
"hosp_data_real_time", "forecast_date"
273+
)) |>
274+
mutate(model_ww = glue::glue("{model}-{include_ww}-{hosp_data_real_time}"))
275+
p <- ggplot(scores_by_loc) +
276+
geom_bar(
277+
aes(
278+
x = forecast_date,
279+
y = wis,
280+
fill = model_ww
281+
),
282+
stat = "identity",
283+
position = "dodge"
284+
) +
285+
theme_bw() +
286+
theme(legend.position = "bottom") +
287+
ggtitle("Scores across all locations by forecast dates")
235288
return(p)
236289
}

R/convert_rolling_sum_to_inc.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ convert_rolling_sum_to_inc <- function(rolling_sums,
3838
stop("rolling_sums cannot be empty", call. = FALSE)
3939
}
4040

41-
if (anyNA(rolling_sums)) {
42-
warning("rolling_sums contains NA values. Function expects right-aligned rolling sums", call. = FALSE) # nolint
43-
}
44-
4541
# Handle initial values
4642
if (is.null(initial_values)) {
4743
# If no initial values provided, assume the first 6 days were zeros

R/convert_to_su_object.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#' Convert the scores to a scoringutils object
2+
#'
3+
#' @param scores_data Data.frame from Variant Nowcast Hub GitHub
4+
#' @importFrom data.table setattr as.data.table
5+
#' @importFrom rlang arg_match
6+
#' @importFrom dplyr rename select
7+
#' @returns scoringutils object
8+
convert_to_su_object <- function(scores_data) {
9+
scores2 <- data.table::as.data.table(scores_data)
10+
class(scores2) <- c("scores", class(scores2))
11+
scores_su <- data.table::setattr(
12+
scores2,
13+
"metrics",
14+
c(
15+
"wis", "underprediction", "overprediction", "dispersion",
16+
"bias", "interval_coverage_50", "interval_coverage_90",
17+
"ae_median"
18+
)
19+
)
20+
return(scores_su)
21+
}

R/fit_arima.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fit_arima <- function(hosp_data_for_fit,
3131
date >= ymd(forecast_date),
3232
date <= ymd(forecast_date) + days(forecast_horizon)
3333
)
34-
34+
hosp_data_real_time <- unique(hosp_data_for_fit$hosp_data_real_time)
3535
auto_arima_model <- auto.arima(hosp_data_for_fit$updated_hosp_7d_count,
3636
seasonal = FALSE,
3737
stepwise = FALSE,
@@ -68,6 +68,9 @@ fit_arima <- function(hosp_data_for_fit,
6868
left_join(hosp_data_eval_forecast,
6969
by = "date"
7070
) |>
71-
mutate(forecast_date = ymd(forecast_date))
71+
mutate(
72+
forecast_date = ymd(forecast_date),
73+
hosp_data_real_time = hosp_data_real_time
74+
)
7275
return(forecast_df)
7376
}

R/fit_wwinference_wrapper.R

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#' @param quantiles_to_save Vector of numerics indicating the quantiles
1515
#' @param ind_filepath Character string of the file path to save the outputs
1616
#' from each model run
17+
#' @param save_draws Boolean indicating whether or not to save the draws,
18+
#' default is FALSE.
1719
#'
1820
#' @returns Data.frame of the quantiles alongside the evaluation data.
1921
#' @autoglobal
@@ -34,9 +36,18 @@ fit_wwinference_wrapper <- function(
3436
calibration_time = 90,
3537
forecast_horizon = 28,
3638
quantiles_to_save = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975),
37-
ind_filepath = file.path("output")) {
39+
ind_filepath = file.path("output"),
40+
save_draws = FALSE) {
3841
loc <- unique(count_data$state)
42+
if ((nrow(ww_data) == 0 || is.null(ww_data)) &&
43+
isTRUE(model_spec$include_ww)) {
44+
model_spec$include_ww <- FALSE
45+
flag_missing_ww <- TRUE
46+
} else {
47+
flag_missing_ww <- FALSE
48+
}
3949
include_ww <- model_spec$include_ww
50+
hosp_data_real_time <- unique(count_data$hosp_data_real_time)
4051
ww_fit_obj <- wwinference(
4152
ww_data = ww_data,
4253
count_data = count_data,
@@ -54,32 +65,37 @@ fit_wwinference_wrapper <- function(
5465
# Save plots
5566
full_fp <- file.path(ind_filepath, this_forecast_date, loc)
5667
if (!file.exists(file.path(full_fp))) {
57-
dir_create(full_fp, recursive = TRUE, showWarnings = FALSE)
68+
dir_create(full_fp, recurse = TRUE)
5869
}
5970
fig_fp <- file.path(full_fp, "figs")
6071
if (!file.exists(file.path(fig_fp))) {
61-
dir_create(fig_fp, recursive = TRUE, showWarnings = FALSE)
72+
dir_create(fig_fp, recurse = TRUE)
6273
}
6374

6475
plot_hosp_draws <- get_plot_forecasted_counts(
6576
draws = hosp_draws,
6677
forecast_date = this_forecast_date
67-
) + ggtitle(glue("{loc}, wastewater: {include_ww}"))
78+
) + ggtitle(glue("{loc}, wastewater: {include_ww}, hosp data real-time: {hosp_data_real_time}")) # nolint
6879

6980
ggsave(
7081
plot = plot_hosp_draws,
7182
filename = file.path(
7283
fig_fp,
73-
glue("hosp_draws_ww_{include_ww}.png")
84+
glue("hosp_draws_ww_{include_ww}_rt_{hosp_data_real_time}.png")
7485
)
7586
)
7687
ww_draws <- if (!is.null(ww_fit_obj$raw_input_data$input_ww_data)) {
7788
get_draws(ww_fit_obj, what = "predicted_ww")$predicted_ww
7889
} else {
7990
NULL
8091
}
92+
data_fp <- file.path(full_fp, "data")
93+
if (!file.exists(file.path(data_fp))) {
94+
dir_create(data_fp, recurse = TRUE)
95+
}
8196

8297
if (!is.null(ww_draws)) {
98+
# Plot
8399
plot_ww_draws <- get_plot_ww_conc(
84100
draws = ww_draws,
85101
forecast_date = this_forecast_date
@@ -91,6 +107,43 @@ fit_wwinference_wrapper <- function(
91107
"ww_draws.png"
92108
)
93109
)
110+
ww_data_obs <- select(
111+
ww_data,
112+
date, site, lab,
113+
log_genome_copies_per_ml, below_lod,
114+
log_lod, flag_as_ww_outlier
115+
)
116+
ww_metadata <- ww_data |>
117+
select(
118+
site, lab, site_pop,
119+
location_name, location_abbr,
120+
forecast_date, lab_site_name
121+
) |>
122+
distinct()
123+
124+
# Get and save quantiles
125+
ww_quantiles <- ww_draws |>
126+
trajectories_to_quantiles(
127+
quantiles = quantiles_to_save,
128+
timepoint_cols = "date",
129+
value_col = "pred_value",
130+
quantile_value_name = "predicted",
131+
quantile_level_name = "quantile_level",
132+
id_cols = c("site", "lab")
133+
) |>
134+
left_join(ww_data_obs,
135+
by = c("date", "site", "lab")
136+
) |>
137+
left_join(ww_metadata,
138+
by = c("site", "lab")
139+
)
140+
write_csv(
141+
ww_quantiles,
142+
file.path(
143+
data_fp,
144+
"ww_quantiles.csv"
145+
)
146+
)
94147
}
95148

96149
draws_w_data <- get_model_draws_w_data(
@@ -100,18 +153,18 @@ fit_wwinference_wrapper <- function(
100153
model = "wwinference",
101154
forecast_date = this_forecast_date,
102155
location = loc,
156+
hosp_data_real_time = hosp_data_real_time,
103157
eval_data = hosp_data_eval
104158
)
105-
data_fp <- file.path(full_fp, "data")
106-
if (!file.exists(file.path(data_fp))) {
107-
dir_create(data_fp, recursive = TRUE, showWarnings = FALSE)
159+
160+
if (isTRUE(save_draws)) {
161+
arrow::write_parquet(
162+
draws_w_data,
163+
file.path(data_fp, glue::glue(
164+
"hosp_draws_ww_{include_ww}_rt_{hosp_data_real_time}.parquet"
165+
))
166+
)
108167
}
109-
write_csv(
110-
draws_w_data,
111-
file.path(data_fp, glue::glue(
112-
"hosp_draws_ww_{include_ww}.csv"
113-
))
114-
)
115168
# Make a plot here with calibration and evaluation data and save it.
116169
get_plot_draws_w_calib_data(
117170
draws_w_data,
@@ -124,7 +177,8 @@ fit_wwinference_wrapper <- function(
124177
offset = 1,
125178
quantiles = TRUE,
126179
probs = quantiles_to_save
127-
)
180+
) |>
181+
mutate(flag_missing_ww = flag_missing_ww)
128182

129183
write_csv(
130184
hosp_quantiles,

0 commit comments

Comments
 (0)