Skip to content

Commit f41e648

Browse files
authored
Merge pull request #168 from mrc-ide/sero_fitting
0.6.7 sero fitting in
2 parents 736dcbe + 3a323c2 commit f41e648

File tree

7 files changed

+180
-6
lines changed

7 files changed

+180
-6
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: squire
22
Type: Package
33
Title: SEIR transmission model of COVID-19
4-
Version: 0.6.6
4+
Version: 0.6.7
55
Authors@R: c(
66
person("OJ", "Watson", email = "o.watson15@imperial.ac.uk", role = c("aut", "cre")),
77
person("Patrick", "Walker", email = "patrick.walker06@imperial.ac.uk", role = c("aut")),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ importFrom(odin,odin)
5656
importFrom(rlang,.data)
5757
importFrom(stats,cor)
5858
importFrom(stats,cov)
59+
importFrom(stats,dbinom)
5960
importFrom(stats,dnbinom)
6061
importFrom(stats,median)
6162
importFrom(stats,plogis)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# squire 0.6.7
2+
3+
* `pmcmc` can now be used to fit to serology data (deterministic model only)
4+
by passing `sero_df` and `sero_det` as names list elements of `pars_obs`
5+
16
# squire 0.6.6
27

38
* `projections` can be used now for `nimue` models

R/particle.R

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ scale_log_weights <- function(log_weights) {
650650
#'
651651
#' @return Results from particle filter
652652
#'
653+
#' @importFrom stats dbinom
653654
run_deterministic_comparison <- function(data,
654655
squire_model,
655656
model_params,
@@ -700,7 +701,7 @@ run_deterministic_comparison <- function(data,
700701
Ds[Ds < 0] <- 0
701702
deaths <- data$deaths[-1]
702703

703-
# calculate ll
704+
# calculate ll for deaths
704705
if (obs_params$treated_deaths_only) {
705706

706707
Ds_heathcare <- diff(rowSums(out[,index$D_get]))
@@ -713,14 +714,57 @@ run_deterministic_comparison <- function(data,
713714

714715
}
715716

717+
# calculate ll for the seroprevalence
718+
lls <- 0
719+
if("sero_df" %in% obs_params && "sero_det" %in% obs_params) {
720+
721+
sero_df <- obs_params$sero_df
722+
sero_det <- obs_params$sero_det
723+
724+
# were there actually seroprevalence data points to compare against
725+
if(nrow(sero_df) > 0) {
726+
727+
sero_at_date <- function(date, symptoms, det, dates, N) {
728+
729+
di <- which(dates == date)
730+
if(length(di) > 0) {
731+
to_sum <- tail(symptoms[seq_len(di)], length(det))
732+
min(sum(rev(to_sum)*head(det, length(to_sum)), na.rm=TRUE)/N, 0.99)
733+
} else {
734+
0
735+
}
736+
737+
}
738+
739+
# get symptom incidence
740+
symptoms <- rowSums(out[,index$E2]) * model_params$gamma_E
741+
742+
# dates of incidence, pop size and dates of sero surveys
743+
dates <- data$date[[1]] + seq_len(nrow(out)) - 1L
744+
N <- sum(model_params$population)
745+
sero_dates <- list(sero_df$date_end, sero_df$date_start, sero_df$date_start + as.integer((sero_df$date_end - sero_df$date_start)/2))
746+
unq_sero_dates <- unique(c(sero_df$date_end, sero_df$date_start, sero_df$date_start + as.integer((sero_df$date_end - sero_df$date_start)/2)))
747+
det <- obs_params$sero_det
748+
749+
# estimate model seroprev
750+
sero_model <- vapply(unq_sero_dates, sero_at_date, numeric(1), symptoms, det, dates, N)
751+
sero_model_mat <- do.call(cbind,lapply(sero_dates, function(x) {sero_model[match(x, unq_sero_dates)]}))
752+
753+
# likelihood of model obvs
754+
lls <- rowMeans(dbinom(sero_df$sero_pos, sero_df$samples, sero_model_mat, log = TRUE))
755+
756+
}
757+
758+
}
759+
716760
# format the out object
717761
date <- data$date[[1]] + seq_len(nrow(out)) - 1L
718762
rownames(out) <- as.character(date)
719763
attr(out, "date") <- date
720764

721765
# format similar to particle_filter nomenclature
722766
pf_results <- list()
723-
pf_results$log_likelihood <- sum(ll)
767+
pf_results$log_likelihood <- sum(ll) + sum(lls)
724768

725769
# single returns final state
726770
if (save_history) {

R/pmcmc.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ pmcmc <- function(data,
293293
}
294294
assert_logical(unlist(pars_discrete))
295295
assert_list(pars_obs)
296-
assert_eq(names(pars_obs), c("phi_cases", "k_cases", "phi_death", "k_death", "exp_noise"))
297-
assert_numeric(unlist(pars_obs))
296+
assert_in(c("phi_cases", "k_cases", "phi_death", "k_death", "exp_noise"), names(pars_obs))
297+
assert_numeric(unlist(pars_obs[c("phi_cases", "k_cases", "phi_death", "k_death", "exp_noise")]))
298298

299299
# mcmc items
300300
assert_pos_int(n_mcmc)
@@ -1315,6 +1315,12 @@ calc_loglikelihood <- function(pars, data, squire_model, model_params,
13151315
R0 <- pars[["R0"]]
13161316
start_date <- pars[["start_date"]]
13171317

1318+
# reporting fraction par if in pars list
1319+
if("rf" %in% names(pars)) {
1320+
assert_numeric(pars[["rf"]])
1321+
pars_obs$phi_death <- pars[["rf"]]
1322+
}
1323+
13181324
#----------------..
13191325
# more assertions
13201326
#----------------..

R/pmcmc_object.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,17 @@ plot_pmcmc_sample <- function(x, what = "deaths") {
4949

5050
base_plot <- plot(x, "deaths", ci = FALSE, replicates = TRUE, x_var = "date",
5151
date_0 = max(x$pmcmc_results$inputs$data$date))
52+
53+
if ("rf" %in% names(x$replicate_parameters)) {
54+
rf <- mean(x$replicate_parameters$rf)
55+
} else {
56+
rf <- x$pmcmc_results$inputs$pars_obs$phi_death
57+
}
58+
5259
base_plot <- base_plot +
5360
ggplot2::geom_line(ggplot2::aes(y=.data$ymin, x=as.Date(.data$date)), quants, linetype="dashed") +
5461
ggplot2::geom_line(ggplot2::aes(y=.data$ymax, x=as.Date(.data$date)), quants, linetype="dashed") +
55-
ggplot2::geom_point(ggplot2::aes(y=.data$deaths/x$pmcmc_results$inputs$pars_obs$phi_death,
62+
ggplot2::geom_point(ggplot2::aes(y=.data$deaths/rf,
5663
x=as.Date(.data$date)), x$pmcmc_results$inputs$data) +
5764
ggplot2::theme(legend.position = "top")
5865

tests/testthat/test-pmcmc.R

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,3 +1643,114 @@ test_that("pmcmc restarting covariance/scaling", {
16431643

16441644

16451645
})
1646+
1647+
1648+
1649+
#------------------------------------------------
1650+
test_that("sero fitting works", {
1651+
1652+
set.seed(12)
1653+
1654+
Sys.setenv("SQUIRE_PARALLEL_DEBUG" = "TRUE")
1655+
data <- read.csv(squire_file("extdata/example.csv"),stringsAsFactors = FALSE)
1656+
interventions <- read.csv(squire_file("extdata/example_intervention.csv"))
1657+
int_unique <- interventions_unique(interventions)
1658+
reporting_fraction = 1
1659+
country = "Algeria"
1660+
pars_init = list('start_date' = as.Date("2020-02-07"),
1661+
'R0' = 3,
1662+
'Meff' = 2,
1663+
"rf" = 0.25) # correct rf for the data
1664+
pars_min = list('start_date' = as.Date("2020-02-01"),
1665+
'R0' = 1e-10,
1666+
'Meff' = 0.1,
1667+
"rf" = 0.1)
1668+
pars_max = list('start_date' = as.Date("2020-02-20"),
1669+
'R0' = 5,
1670+
'Meff' = 5,
1671+
"rf" = 1)
1672+
pars_discrete = list('start_date' = TRUE,
1673+
'R0' = FALSE,
1674+
'Meff' = FALSE,
1675+
'rf' = FALSE)
1676+
pars_obs = list(phi_cases = 0.1,
1677+
k_cases = 2,
1678+
phi_death = 1,
1679+
k_death = 2,
1680+
exp_noise = 1e6)
1681+
1682+
steps_per_day = 1
1683+
R0_change = int_unique$change
1684+
date_R0_change = as.Date(int_unique$dates_change)
1685+
date_contact_matrix_set_change = NULL
1686+
squire_model = squire:::deterministic_model()
1687+
n_particles = 2
1688+
# proposal kernel covriance
1689+
proposal_kernel <- matrix(0.5, ncol=length(pars_init), nrow = length(pars_init))
1690+
diag(proposal_kernel) <- 1
1691+
rownames(proposal_kernel) <- colnames(proposal_kernel) <- names(pars_init)
1692+
1693+
sero_df <- data.frame("samples" = 1000, "sero_pos" = 10,
1694+
"date_start" = as.Date("2020-04-15"),
1695+
"date_end" = as.Date("2020-04-19"))
1696+
# seroconversion data from brazeay report 34
1697+
prob_conversion <- cumsum(dgamma(0:300,shape = 5, rate = 1/2))/max(cumsum(dgamma(0:300,shape = 5, rate = 1/2)))
1698+
sero_det <- cumsum(dweibull(0:300, 3.669807, scale = 143.7046))
1699+
sero_det <- prob_conversion-sero_det
1700+
sero_det[sero_det < 0] <- 0
1701+
sero_det <- sero_det/max(sero_det)
1702+
1703+
pars_obs$sero_df <- sero_df
1704+
pars_obs$sero_det <- sero_det
1705+
1706+
Sys.setenv("SQUIRE_PARALLEL_DEBUG"=TRUE)
1707+
out <- pmcmc(data = data,
1708+
n_mcmc = 5,
1709+
log_likelihood = NULL,
1710+
log_prior = NULL,
1711+
n_particles = 2,
1712+
steps_per_day = steps_per_day,
1713+
output_proposals = FALSE,
1714+
n_chains = 1,
1715+
replicates = 20,
1716+
burnin = 5,
1717+
squire_model = squire_model,
1718+
pars_init = pars_init,
1719+
pars_min = pars_min,
1720+
pars_max = pars_max,
1721+
pars_discrete = pars_discrete,
1722+
pars_obs = pars_obs,
1723+
proposal_kernel = proposal_kernel,
1724+
R0_change = R0_change,
1725+
date_R0_change = date_R0_change,
1726+
country = country)
1727+
1728+
pars_init$rf <- 1
1729+
1730+
out2 <- pmcmc(data = data,
1731+
n_mcmc = 5,
1732+
log_likelihood = NULL,
1733+
log_prior = NULL,
1734+
n_particles = 2,
1735+
steps_per_day = steps_per_day,
1736+
output_proposals = FALSE,
1737+
n_chains = 1,
1738+
replicates = 20,
1739+
burnin = 5,
1740+
squire_model = squire_model,
1741+
pars_init = pars_init,
1742+
pars_min = pars_min,
1743+
pars_max = pars_max,
1744+
pars_discrete = pars_discrete,
1745+
pars_obs = pars_obs,
1746+
proposal_kernel = proposal_kernel,
1747+
R0_change = R0_change,
1748+
date_R0_change = date_R0_change,
1749+
country = country)
1750+
1751+
expect_gt(sum(out$pmcmc_results$results$log_likelihood),
1752+
sum(out2$pmcmc_results$results$log_likelihood))
1753+
1754+
expect_s3_class(plot(out, what = "deaths", particle_fit = TRUE), "gg")
1755+
1756+
})

0 commit comments

Comments
 (0)