Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ The function interface remains unchanged.

## Model changes

- Refactored Stan code for efficiency (@bob-carpenter, #1273).
- MCMC runs are now initialised with parameter values drawn from a distribution that approximates their prior distributions.
- Added an option to compute growth rates using an estimator by Parag et al. (2022) based on total infectiousness rather than new infections, see `growth_method` argument in rt_opts().
- Added support for fitting the susceptible population size.
Expand Down
11 changes: 5 additions & 6 deletions inst/stan/functions/convolve.stan
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
vector convolve_with_rev_pmf(vector x, vector y, int len) {
int xlen = num_elements(x);
int ylen = num_elements(y);
vector[len] z;

if (xlen + ylen - 1 < len) {
reject("convolve_with_rev_pmf: len is longer than x and y convolved");
Expand All @@ -72,16 +71,16 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) {
reject("convolve_with_rev_pmf: len is shorter than x");
}

vector[len] z;

for (s in 1:xlen) {
array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}

if (len > xlen) {
for (s in (xlen + 1):len) {
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}
for (s in (xlen + 1):len) { // zero iterations unless len > xlen
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}

return z;
Expand Down
53 changes: 31 additions & 22 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
return factor * exp(exponent * square(indices));
}

/**
* Index set for M basis functions of length L for Matern kernel.
*
* The function returns pow(pi() / 2 / L * linspaced_vector(M, 1, M), 2),
* or equivalently, square(pi() / (2 * L) * linspaced_vector(M, 1, M)).
*
* @param L Length of the interval
* @param M Number of basis functions
* @return Linearly spaced M-vector
*/
vector matern_indices(int M, real L) {
real factor = pi() / (2 * L);
return square(linspaced_vector(M, factor, factor * M));
}

/**
* Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel
*
Expand All @@ -35,10 +50,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2;
vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2));
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 1 / rho + rho * matern_indices(M, L);
return alpha * sqrt(2 ./ denom);
}

/**
Expand All @@ -53,10 +66,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5);
vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2);
return factor * inv(denom);
real factor = 2 * alpha * (sqrt(3) / rho)^1.5;
vector[M] denom = 3 / square(rho) + matern_indices(M, L);
return factor ./ denom;
}

/**
Expand All @@ -71,11 +83,9 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern52(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 16 * pow(sqrt(5) / rho, 5);
vector[M] denom =
3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3);
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3);
return alpha * sqrt(factor ./ denom);
}

/**
Expand All @@ -92,10 +102,11 @@ vector diagSPD_Periodic(real alpha, real rho, int M) {
real a = inv_square(rho);
vector[M] indices = linspaced_vector(M, 1, M);
vector[M] q = exp(
log(alpha) + 0.5 *
(log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a)))
log(alpha) +
0.5 * (log2() - a + log_modified_bessel_first_kind(indices, a))
);
return append_row(q, q);

}

/**
Expand Down Expand Up @@ -129,11 +140,11 @@ matrix PHI(int N, int M, real L, vector x) {
*
* @ingroup estimates_smoothing
*/

matrix PHI_periodic(int N, int M, real w0, vector x) {
matrix[N, M] mw0x = diag_post_multiply(
rep_matrix(w0 * x, M), linspaced_vector(M, 1, M)
);
return append_col(cos(mw0x), sin(mw0x));
row_vector[M] k = linspaced_row_vector(M, 1, M);
matrix[N, M] w0xk = (w0 * x) * k;
return append_col(cos(w0xk), sin(w0xk));
}

/**
Expand All @@ -153,9 +164,7 @@ matrix PHI_periodic(int N, int M, real w0, vector x) {
int setup_noise(int ot_h, int t, int horizon, int estimate_r,
int stationary, int future_fixed, int fixed_from) {
int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t;
int noise_terms =
future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
return noise_terms;
return future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
}

/**
Expand Down Expand Up @@ -210,7 +219,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha,
} else if (nu == 2.5) {
diagSPD = diagSPD_Matern52(alpha, rho, L, M);
} else {
reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu);
reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu);
}
}
return PHI * (diagSPD .* eta);
Expand Down
Loading