Skip to content

Commit 3d7f616

Browse files
author
Craig Gower-Page
authored
Fix Claret Unit test (#441)
1 parent 6b20f37 commit 3d7f616

File tree

11 files changed

+617
-44
lines changed

11 files changed

+617
-44
lines changed

design/debug-cb/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
**/claret_bruno

design/debug-cb/A/README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
3+
4+
5+
6+
## Model Specification
7+
8+
Fixed effects only model, that is we have no patient level random effects. We have no study or arm level hierarchical effects. There is no censoring and we only account for strictly positive time.
9+
10+
11+
$$
12+
\begin{align*}
13+
y_i &\sim \mathcal{N} \left( \mu_i, \mu_i^2 \sigma^2 \right) \\
14+
\\
15+
\mu_i =
16+
b \cdot &\exp \left( g t_{i} - \frac{p}{c} \left( 1 - e^{-c t_{i}} \right) \right)
17+
\end{align*}
18+
$$
19+
20+
- $i$ is the observation index
21+
22+
### Priors
23+
24+
$$
25+
\begin{align*}
26+
b &\sim \text{LogNormal} \left( \right) \\
27+
g &\sim \text{LogNormal} \left( \right) \\
28+
p &\sim \text{LogNormal} \left( \right) \\
29+
c &\sim \text{LogNormal} \left( \right) \\
30+
\sigma &\sim \text{LogNormal} \left( \right)
31+
\end{align*}
32+
$$
33+
34+
35+
36+
37+

design/debug-cb/A/claret_bruno.R

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
3+
library(cmdstanr)
4+
library(dplyr)
5+
library(ggplot2)
6+
library(bayesplot)
7+
8+
9+
get_sld <- function(t, b, g, c, p) {
10+
b * exp((g * t) - (p/c) * (1 - exp(-c * t)))
11+
}
12+
13+
pars <- list(
14+
b = 60,
15+
g = 0.5,
16+
c = 0.4,
17+
p = 0.7,
18+
sigma = 0.002
19+
)
20+
21+
dat <- tibble(
22+
t = seq(1, 900, by = 5) / 365,
23+
sld_mu = get_sld(t, pars$b, pars$g, pars$c, pars$p),
24+
sld = rnorm(length(t), sld_mu, sld_mu * pars$sigma)
25+
)
26+
27+
ggplot(data = dat, aes(x = t, y = sld)) +
28+
geom_point() +
29+
geom_line(aes(y = sld_mu), color = "red")
30+
31+
32+
mod <- cmdstan_model(
33+
stan_file = here::here("design/debug-cb/A/claret_bruno.stan")
34+
)
35+
36+
stan_data <- list(
37+
N = nrow(dat),
38+
values = dat$sld,
39+
times = dat$t
40+
)
41+
42+
fit <- mod$sample(
43+
data = stan_data,
44+
chains = 2,
45+
parallel_chains = 2,
46+
refresh = 200,
47+
iter_warmup = 1000,
48+
iter_sampling = 2000
49+
)
50+
51+
fit$summary()
52+
53+
54+
55+
56+
57+
58+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
functions {
3+
vector claret_bruno_mu(vector times, real b, real c, real p, real g) {
4+
vector[rows(times)] values;
5+
values = b * exp(
6+
(g .* times) - (
7+
(p/c) .* (1 - exp(-c .* times))
8+
)
9+
);
10+
return values;
11+
}
12+
}
13+
14+
15+
data {
16+
int <lower=0> N;
17+
vector[N] values;
18+
vector[N] times;
19+
}
20+
21+
parameters {
22+
real <lower=0> b;
23+
real <lower=0> c;
24+
real <lower=0> p;
25+
real <lower=0> g;
26+
real <lower=0> sigma;
27+
}
28+
29+
transformed parameters {
30+
vector[N] mu = claret_bruno_mu(times, b, c, p, g);
31+
}
32+
33+
// pars <- list(
34+
// b = 60,
35+
// g = 0.5,
36+
// c = 0.4,
37+
// p = 0.7,
38+
// sigma = 0.004
39+
// )
40+
41+
model {
42+
b ~ lognormal(log(60), 0.5);
43+
c ~ lognormal(log(0.5), 0.5);
44+
p ~ normal(log(0.4), 0.5);
45+
g ~ lognormal(log(0.7), 0.5);
46+
sigma ~ lognormal(log(0.004), 0.5);
47+
values ~ normal(mu, mu * sigma);
48+
}
49+
50+
51+
52+

design/debug-cb/B/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
3+
4+
5+
6+
## Model Specification
7+
8+
Simple by-patient Random effects model. We have no study or arm level hierarchical effects. There is no censoring and we only account for strictly positive time.
9+
10+
11+
$$
12+
\begin{align*}
13+
y_{ij} &\sim \mathcal{N} \left( \mu_{ij}, \mu_{ij}^2 \sigma^2 \right) \\
14+
\\
15+
\mu_{ij} &=
16+
b_i \cdot \exp \left( g_i t_{ij} - \frac{p_i}{c_i} \left( 1 - e^{-c_i t_{ij}} \right) \right)
17+
\end{align*}
18+
$$
19+
20+
21+
- $i$ is the patient index
22+
- $j$ is the time index
23+
24+
### Priors
25+
26+
$$
27+
\begin{align*}
28+
b_i &\sim \text{LogNormal} \left( log(\mu_b) , \sigma_b \right) \\
29+
g_i &\sim \text{LogNormal} \left( log(\mu_g) , \sigma_g \right) \\
30+
p_i &\sim \text{LogNormal} \left( log(\mu_p) , \sigma_p \right) \\
31+
c_i &\sim \text{LogNormal} \left( log(\mu_c) , \sigma_c \right) \\
32+
\\
33+
\mu_b &\sim \text{LogNormal} \left( \right) \\
34+
\mu_g &\sim \text{LogNormal} \left( \right) \\
35+
\mu_p &\sim \text{LogNormal} \left( \right) \\
36+
\mu_c &\sim \text{LogNormal} \left( \right) \\
37+
\\
38+
\sigma_b &\sim \text{LogNormal} \left( \right) \\
39+
\sigma_g &\sim \text{LogNormal} \left( \right) \\
40+
\sigma_p &\sim \text{LogNormal} \left( \right) \\
41+
\sigma_c &\sim \text{LogNormal} \left( \right) \\
42+
\\
43+
\sigma &\sim \text{LogNormal} \left( \right) \\
44+
\end{align*}
45+
$$
46+
47+
48+
49+
50+

design/debug-cb/B/claret_bruno.R

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
2+
3+
library(cmdstanr)
4+
library(dplyr)
5+
library(ggplot2)
6+
library(bayesplot)
7+
library(tidyr)
8+
library(brms)
9+
10+
11+
get_sld <- function(t, b, g, c, p) {
12+
b * exp((g * t) - (p/c) * (1 - exp(-c * t)))
13+
}
14+
15+
pars_mu <- list(
16+
b = 60,
17+
g = 0.5,
18+
c = 0.4,
19+
p = 0.7,
20+
sigma = 0.02
21+
)
22+
23+
pars_sigma <- list(
24+
b = 0.1,
25+
g = 0.1,
26+
c = 0.1,
27+
p = 0.1
28+
)
29+
30+
N <- 120
31+
32+
dat_baseline <- tibble(
33+
pt = sprintf("pt%05d", 1:N),
34+
b = exp(rnorm(N, log(pars_mu$b), pars_sigma$b)),
35+
g = exp(rnorm(N, log(pars_mu$g), pars_sigma$g)),
36+
c = exp(rnorm(N, log(pars_mu$c), pars_sigma$c)),
37+
p = exp(rnorm(N, log(pars_mu$p), pars_sigma$p))
38+
)
39+
40+
dat <- tidyr::crossing(
41+
pt = dat_baseline$pt,
42+
t = seq(1, 900, length.out = 8) / 365
43+
) |>
44+
left_join(dat_baseline, by = "pt") |>
45+
mutate(
46+
sld_mu = get_sld(t, b, g, c, p),
47+
sld = rnorm(n(), sld_mu, sld_mu * pars_mu$sigma)
48+
) |>
49+
arrange(pt, t) |>
50+
mutate(pt = factor(pt))
51+
52+
53+
mod <- cmdstan_model(
54+
stan_file = here::here("design/debug-cb/B/claret_bruno.stan")
55+
)
56+
57+
stan_data <- list(
58+
N_obs = nrow(dat),
59+
N_pt = N,
60+
pt_index = as.numeric(dat$pt),
61+
values = dat$sld,
62+
times = dat$t
63+
)
64+
65+
66+
fit <- mod$sample(
67+
data = stan_data,
68+
chains = 3,
69+
parallel_chains = 3,
70+
refresh = 200,
71+
iter_warmup = 1500,
72+
iter_sampling = 2000
73+
)
74+
75+
fit
76+
fit$summary()
77+
78+
79+
80+
81+
82+
######################
83+
#
84+
# brms implementation
85+
#
86+
#
87+
88+
89+
# pars_mu <- list(
90+
# b = 60,
91+
# g = 0.5,
92+
# c = 0.4,
93+
# p = 0.7,
94+
# sigma = 0.02
95+
# )
96+
97+
bfit <- brm(
98+
bf(
99+
value ~ exp(b) * exp( exp(g) * t - exp(p-c) * (1 - exp(- exp(c) * t))),
100+
b ~ 1 + (1 | pt),
101+
g ~ 1 + (1 | pt),
102+
c ~ 1 + (1 | pt),
103+
p ~ 1 + (1 | pt),
104+
nl = TRUE
105+
),
106+
data = dat |> select(pt, value = sld, t),
107+
prior = c(
108+
prior("normal(log(60), 0.3)", nlpar = "b"), # b intercept
109+
prior("normal(log(0.5), 0.3)", nlpar = "g"), # g intercept
110+
prior("normal(log(0.4), 0.3)", nlpar = "c"), # c intercept
111+
prior("normal(log(0.7), 0.3)", nlpar = "p"), # p intercept
112+
prior("lognormal(log(0.1), 0.3)", nlpar = "b", class = "sd"), # b random effect sigma
113+
prior("lognormal(log(0.1), 0.3)", nlpar = "g", class = "sd"), # g random effect sigma
114+
prior("lognormal(log(0.1), 0.3)", nlpar = "c", class = "sd"), # c random effect sigma
115+
prior("lognormal(log(0.1), 0.3)", nlpar = "p", class = "sd"), # p random effect sigma
116+
prior("lognormal(log(0.02), 0.3)", class = "sigma") # overall sigma
117+
),
118+
warmup = 1500,
119+
iter = 2500,
120+
chains = 3,
121+
cores = 3,
122+
backend = "cmdstanr",
123+
control = list(adapt_delta = 0.95)
124+
)
125+
126+
127+
128+
129+
#####################
130+
#
131+
# Debugging
132+
#
133+
#
134+
135+
136+
137+
# Plot patient profiles
138+
pdat <- dat |> filter(pt %in% sample(dat$pt, 5))
139+
140+
ggplot(data = pdat, aes(x = t, y = sld, group = pt, col = pt)) +
141+
geom_point() +
142+
geom_line(aes(y = sld_mu))
143+
144+
145+
# Plottig priors
146+
plot(density(exp(rnorm(5000, log(0.6), 0.2))))
147+
148+
149+
# Plotting Joint Priors
150+
N <- 100000
151+
mu <- rnorm(N, log(0.6), 0.3)
152+
sigma <- exp(rnorm(N, log(0.1), 0.3))
153+
value <- exp(rnorm(N, mu, sigma))
154+
155+
pdat <- tibble(
156+
mu = mu,
157+
sigma = sigma,
158+
value = value
159+
)
160+
161+
ggplot(data = pdat, aes(x =value, y = mu)) +
162+
geom_bin2d(bins = 300)
163+
164+

0 commit comments

Comments
 (0)