|
| 1 | + |
| 2 | + |
| 3 | +library(cmdstanr) |
| 4 | +library(dplyr) |
| 5 | +library(ggplot2) |
| 6 | +library(bayesplot) |
| 7 | +library(tidyr) |
| 8 | +library(brms) |
| 9 | + |
| 10 | + |
| 11 | +get_sld <- function(t, b, g, c, p) { |
| 12 | + b * exp((g * t) - (p/c) * (1 - exp(-c * t))) |
| 13 | +} |
| 14 | + |
| 15 | +pars_mu <- list( |
| 16 | + b = 60, |
| 17 | + g = 0.5, |
| 18 | + c = 0.4, |
| 19 | + p = 0.7, |
| 20 | + sigma = 0.02 |
| 21 | +) |
| 22 | + |
| 23 | +pars_sigma <- list( |
| 24 | + b = 0.1, |
| 25 | + g = 0.1, |
| 26 | + c = 0.1, |
| 27 | + p = 0.1 |
| 28 | +) |
| 29 | + |
| 30 | +N <- 120 |
| 31 | + |
| 32 | +dat_baseline <- tibble( |
| 33 | + pt = sprintf("pt%05d", 1:N), |
| 34 | + b = exp(rnorm(N, log(pars_mu$b), pars_sigma$b)), |
| 35 | + g = exp(rnorm(N, log(pars_mu$g), pars_sigma$g)), |
| 36 | + c = exp(rnorm(N, log(pars_mu$c), pars_sigma$c)), |
| 37 | + p = exp(rnorm(N, log(pars_mu$p), pars_sigma$p)) |
| 38 | +) |
| 39 | + |
| 40 | +dat <- tidyr::crossing( |
| 41 | + pt = dat_baseline$pt, |
| 42 | + t = seq(1, 900, length.out = 8) / 365 |
| 43 | +) |> |
| 44 | + left_join(dat_baseline, by = "pt") |> |
| 45 | + mutate( |
| 46 | + sld_mu = get_sld(t, b, g, c, p), |
| 47 | + sld = rnorm(n(), sld_mu, sld_mu * pars_mu$sigma) |
| 48 | + ) |> |
| 49 | + arrange(pt, t) |> |
| 50 | + mutate(pt = factor(pt)) |
| 51 | + |
| 52 | + |
| 53 | +mod <- cmdstan_model( |
| 54 | + stan_file = here::here("design/debug-cb/B/claret_bruno.stan") |
| 55 | +) |
| 56 | + |
| 57 | +stan_data <- list( |
| 58 | + N_obs = nrow(dat), |
| 59 | + N_pt = N, |
| 60 | + pt_index = as.numeric(dat$pt), |
| 61 | + values = dat$sld, |
| 62 | + times = dat$t |
| 63 | +) |
| 64 | + |
| 65 | + |
| 66 | +fit <- mod$sample( |
| 67 | + data = stan_data, |
| 68 | + chains = 3, |
| 69 | + parallel_chains = 3, |
| 70 | + refresh = 200, |
| 71 | + iter_warmup = 1500, |
| 72 | + iter_sampling = 2000 |
| 73 | +) |
| 74 | + |
| 75 | +fit |
| 76 | +fit$summary() |
| 77 | + |
| 78 | + |
| 79 | + |
| 80 | + |
| 81 | + |
| 82 | +###################### |
| 83 | +# |
| 84 | +# brms implementation |
| 85 | +# |
| 86 | +# |
| 87 | + |
| 88 | + |
| 89 | +# pars_mu <- list( |
| 90 | +# b = 60, |
| 91 | +# g = 0.5, |
| 92 | +# c = 0.4, |
| 93 | +# p = 0.7, |
| 94 | +# sigma = 0.02 |
| 95 | +# ) |
| 96 | + |
| 97 | +bfit <- brm( |
| 98 | + bf( |
| 99 | + value ~ exp(b) * exp( exp(g) * t - exp(p-c) * (1 - exp(- exp(c) * t))), |
| 100 | + b ~ 1 + (1 | pt), |
| 101 | + g ~ 1 + (1 | pt), |
| 102 | + c ~ 1 + (1 | pt), |
| 103 | + p ~ 1 + (1 | pt), |
| 104 | + nl = TRUE |
| 105 | + ), |
| 106 | + data = dat |> select(pt, value = sld, t), |
| 107 | + prior = c( |
| 108 | + prior("normal(log(60), 0.3)", nlpar = "b"), # b intercept |
| 109 | + prior("normal(log(0.5), 0.3)", nlpar = "g"), # g intercept |
| 110 | + prior("normal(log(0.4), 0.3)", nlpar = "c"), # c intercept |
| 111 | + prior("normal(log(0.7), 0.3)", nlpar = "p"), # p intercept |
| 112 | + prior("lognormal(log(0.1), 0.3)", nlpar = "b", class = "sd"), # b random effect sigma |
| 113 | + prior("lognormal(log(0.1), 0.3)", nlpar = "g", class = "sd"), # g random effect sigma |
| 114 | + prior("lognormal(log(0.1), 0.3)", nlpar = "c", class = "sd"), # c random effect sigma |
| 115 | + prior("lognormal(log(0.1), 0.3)", nlpar = "p", class = "sd"), # p random effect sigma |
| 116 | + prior("lognormal(log(0.02), 0.3)", class = "sigma") # overall sigma |
| 117 | + ), |
| 118 | + warmup = 1500, |
| 119 | + iter = 2500, |
| 120 | + chains = 3, |
| 121 | + cores = 3, |
| 122 | + backend = "cmdstanr", |
| 123 | + control = list(adapt_delta = 0.95) |
| 124 | +) |
| 125 | + |
| 126 | + |
| 127 | + |
| 128 | + |
| 129 | +##################### |
| 130 | +# |
| 131 | +# Debugging |
| 132 | +# |
| 133 | +# |
| 134 | + |
| 135 | + |
| 136 | + |
| 137 | +# Plot patient profiles |
| 138 | +pdat <- dat |> filter(pt %in% sample(dat$pt, 5)) |
| 139 | + |
| 140 | +ggplot(data = pdat, aes(x = t, y = sld, group = pt, col = pt)) + |
| 141 | + geom_point() + |
| 142 | + geom_line(aes(y = sld_mu)) |
| 143 | + |
| 144 | + |
| 145 | +# Plottig priors |
| 146 | +plot(density(exp(rnorm(5000, log(0.6), 0.2)))) |
| 147 | + |
| 148 | + |
| 149 | +# Plotting Joint Priors |
| 150 | +N <- 100000 |
| 151 | +mu <- rnorm(N, log(0.6), 0.3) |
| 152 | +sigma <- exp(rnorm(N, log(0.1), 0.3)) |
| 153 | +value <- exp(rnorm(N, mu, sigma)) |
| 154 | + |
| 155 | +pdat <- tibble( |
| 156 | + mu = mu, |
| 157 | + sigma = sigma, |
| 158 | + value = value |
| 159 | +) |
| 160 | + |
| 161 | +ggplot(data = pdat, aes(x =value, y = mu)) + |
| 162 | + geom_bin2d(bins = 300) |
| 163 | + |
| 164 | + |
0 commit comments