Skip to content

Commit e0ad884

Browse files
committed
Enable setting priors for linear models
1 parent d36e37a commit e0ad884

14 files changed

+184
-22
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: BayesERtools
22
Type: Package
33
Title: Bayesian Exposure-Response Analysis Tools
4-
Version: 0.2.1
4+
Version: 0.2.1.1001
55
Authors@R:
66
c(person("Kenta", "Yoshida", , "yoshida.kenta.6@gmail.com", role = c("aut", "cre"),
77
comment = c(ORCID = "0000-0003-4967-3831")),

NAMESPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ S3method(as_draws_df,ermod)
66
S3method(as_draws_list,ermod)
77
S3method(as_draws_matrix,ermod)
88
S3method(as_draws_rvars,ermod)
9+
S3method(bayestestR::p_direction,ermod_bin)
910
S3method(coef,ermod)
1011
S3method(extract_data,ermod)
1112
S3method(extract_data,ersim)
@@ -26,7 +27,6 @@ S3method(extract_var_selected,ermod_cov_sel)
2627
S3method(loo,ermod)
2728
S3method(loo,ermod_bin_emax)
2829
S3method(loo,ermod_emax)
29-
S3method(p_direction,ermod_bin)
3030
S3method(plot,ermod_bin)
3131
S3method(plot,ermod_cov_sel)
3232
S3method(plot,ermod_exp_sel)
@@ -41,6 +41,7 @@ S3method(print,ermod)
4141
S3method(print,ermod_cov_sel)
4242
S3method(print,ermod_exp_sel)
4343
S3method(print,kfold_cv_ermod)
44+
S3method(prior_summary,ermod)
4445
S3method(summary,ermod)
4546
export(.dev_ermod_refmodel)
4647
export(.select_cov_projpred)
@@ -74,7 +75,6 @@ export(extract_var_exposure)
7475
export(extract_var_resp)
7576
export(extract_var_selected)
7677
export(loo)
77-
export(p_direction)
7878
export(plot_coveff)
7979
export(plot_er)
8080
export(plot_er_exp_sel)
@@ -90,11 +90,11 @@ export(sim_er_curve)
9090
export(sim_er_curve_marg)
9191
export(sim_er_new_exp)
9292
export(sim_er_new_exp_marg)
93-
importFrom(bayestestR,p_direction)
9493
importFrom(loo,loo)
9594
importFrom(posterior,as_draws)
9695
importFrom(posterior,as_draws_array)
9796
importFrom(posterior,as_draws_df)
9897
importFrom(posterior,as_draws_list)
9998
importFrom(posterior,as_draws_matrix)
10099
importFrom(posterior,as_draws_rvars)
100+
importFrom(rstanarm,prior_summary)

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# BayesERtools 0.2.2 (in development)
2+
3+
## Minor changes
4+
5+
* Enable setting the prior distribution for linear models
6+
* Update package dependencies
7+
18
# BayesERtools 0.2.1
29

310
* Update package dependency

R/dev_ermod_lin.R

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ dev_ermod_bin <- function(
3737
var_resp,
3838
var_exposure,
3939
var_cov = NULL,
40+
prior = rstanarm::default_prior_coef(stats::binomial()),
41+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
4042
verbosity_level = 1,
4143
chains = 4,
4244
iter = 2000) {
@@ -68,6 +70,8 @@ dev_ermod_bin <- function(
6870
formula_final,
6971
family = stats::binomial(),
7072
data = data,
73+
prior = prior,
74+
prior_intercept = prior_intercept,
7175
QR = dplyr::if_else(length(var_full) > 1, TRUE, FALSE),
7276
refresh = refresh,
7377
chains = chains,
@@ -116,9 +120,18 @@ dev_ermod_bin_exp_sel <- function(
116120
data,
117121
var_resp,
118122
var_exp_candidates,
123+
prior = rstanarm::default_prior_coef(stats::binomial()),
124+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
119125
verbosity_level = 1,
120126
chains = 4,
121127
iter = 2000) {
128+
fun_dev_ermod <-
129+
purrr::partial(
130+
dev_ermod_bin,
131+
prior = prior,
132+
prior_intercept = prior_intercept
133+
)
134+
122135
l_out <-
123136
.dev_ermod_exp_sel(
124137
data = data,
@@ -127,7 +140,7 @@ dev_ermod_bin_exp_sel <- function(
127140
verbosity_level = verbosity_level,
128141
chains = chains,
129142
iter = iter,
130-
fun_dev_ermod = dev_ermod_bin
143+
fun_dev_ermod = fun_dev_ermod
131144
)
132145

133146
new_ermod_bin_exp_sel(l_out)
@@ -185,9 +198,18 @@ dev_ermod_bin_cov_sel <- function(
185198
validate_search = FALSE,
186199
nterms_max = NULL,
187200
.reduce_obj_size = TRUE,
201+
prior = rstanarm::default_prior_coef(stats::binomial()),
202+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
188203
verbosity_level = 1,
189204
chains = 4,
190205
iter = 2000) {
206+
fun_dev_ermod <-
207+
purrr::partial(
208+
dev_ermod_bin,
209+
prior = prior,
210+
prior_intercept = prior_intercept
211+
)
212+
191213
ll <- .dev_ermod_cov_sel(
192214
data = data,
193215
var_resp = var_resp,
@@ -201,8 +223,10 @@ dev_ermod_bin_cov_sel <- function(
201223
verbosity_level = verbosity_level,
202224
chains = chains,
203225
iter = iter,
204-
fun_dev_ermod = dev_ermod_bin,
205-
fun_family = quote(stats::binomial())
226+
fun_dev_ermod = fun_dev_ermod,
227+
fun_family = quote(stats::binomial()),
228+
prior = prior,
229+
prior_intercept = prior_intercept
206230
)
207231

208232
with(ll, new_ermod_bin_cov_sel(
@@ -243,6 +267,9 @@ dev_ermod_lin <- function(
243267
var_resp,
244268
var_exposure,
245269
var_cov = NULL,
270+
prior = rstanarm::default_prior_coef(stats::binomial()),
271+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
272+
prior_aux = rstanarm::exponential(autoscale = TRUE),
246273
verbosity_level = 1,
247274
chains = 4,
248275
iter = 2000) {
@@ -272,6 +299,9 @@ dev_ermod_lin <- function(
272299
formula_final,
273300
family = stats::gaussian(),
274301
data = data,
302+
prior = prior,
303+
prior_intercept = prior_intercept,
304+
prior_aux = prior_aux,
275305
QR = dplyr::if_else(length(var_full) > 1, TRUE, FALSE),
276306
refresh = refresh,
277307
chains = chains,
@@ -308,9 +338,20 @@ dev_ermod_lin_exp_sel <- function(
308338
data,
309339
var_resp,
310340
var_exp_candidates,
341+
prior = rstanarm::default_prior_coef(stats::binomial()),
342+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
343+
prior_aux = rstanarm::exponential(autoscale = TRUE),
311344
verbosity_level = 1,
312345
chains = 4,
313346
iter = 2000) {
347+
fun_dev_ermod <-
348+
purrr::partial(
349+
dev_ermod_lin,
350+
prior = prior,
351+
prior_intercept = prior_intercept,
352+
prior_aux = prior_aux
353+
)
354+
314355
l_out <-
315356
.dev_ermod_exp_sel(
316357
data = data,
@@ -319,7 +360,7 @@ dev_ermod_lin_exp_sel <- function(
319360
verbosity_level = verbosity_level,
320361
chains = chains,
321362
iter = iter,
322-
fun_dev_ermod = dev_ermod_lin
363+
fun_dev_ermod = fun_dev_ermod
323364
)
324365

325366
new_ermod_lin_exp_sel(l_out)
@@ -352,9 +393,20 @@ dev_ermod_lin_cov_sel <- function(
352393
validate_search = FALSE,
353394
nterms_max = NULL,
354395
.reduce_obj_size = TRUE,
396+
prior = rstanarm::default_prior_coef(stats::binomial()),
397+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
398+
prior_aux = rstanarm::exponential(autoscale = TRUE),
355399
verbosity_level = 1,
356400
chains = 4,
357401
iter = 2000) {
402+
fun_dev_ermod <-
403+
purrr::partial(
404+
dev_ermod_lin,
405+
prior = prior,
406+
prior_intercept = prior_intercept,
407+
prior_aux = prior_aux
408+
)
409+
358410
ll <- .dev_ermod_cov_sel(
359411
data = data,
360412
var_resp = var_resp,
@@ -368,8 +420,11 @@ dev_ermod_lin_cov_sel <- function(
368420
verbosity_level = verbosity_level,
369421
chains = chains,
370422
iter = iter,
371-
fun_dev_ermod = dev_ermod_lin,
372-
fun_family = quote(stats::gaussian())
423+
fun_dev_ermod = fun_dev_ermod,
424+
fun_family = quote(stats::gaussian()),
425+
prior = prior,
426+
prior_intercept = prior_intercept,
427+
prior_aux = prior_aux
373428
)
374429

375430
with(ll, new_ermod_lin_cov_sel(
@@ -468,7 +523,10 @@ dev_ermod_lin_cov_sel <- function(
468523
chains = 4,
469524
iter = 2000,
470525
fun_dev_ermod,
471-
fun_family) {
526+
fun_family,
527+
prior = rstanarm::default_prior_coef(stats::binomial()),
528+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
529+
prior_aux = rstanarm::exponential(autoscale = TRUE)) {
472530
stopifnot(verbosity_level %in% c(0, 1, 2, 3))
473531

474532
rlang::check_installed("projpred")
@@ -492,7 +550,10 @@ dev_ermod_lin_cov_sel <- function(
492550
var_cov_candidates = var_cov_candidates,
493551
verbosity_level = verbosity_level,
494552
chains = chains, iter = iter,
495-
fun_family = fun_family
553+
fun_family = fun_family,
554+
prior = prior,
555+
prior_intercept = prior_intercept,
556+
prior_aux = prior_aux
496557
)
497558

498559
if (verbosity_level >= 1) cli::cli_h2("Step 2: Variable selection")
@@ -567,7 +628,10 @@ NULL
567628
.dev_ermod_refmodel <- function(
568629
data, var_resp, var_exposure, var_cov_candidates,
569630
verbosity_level = 1, chains = 4, iter = 2000,
570-
fun_family = quote(stats::binomial())) {
631+
fun_family = quote(stats::binomial()),
632+
prior = rstanarm::default_prior_coef(stats::binomial()),
633+
prior_intercept = rstanarm::default_prior_intercept(stats::binomial()),
634+
prior_aux = rstanarm::exponential(autoscale = TRUE)) {
571635
stopifnot(verbosity_level %in% c(0, 1, 2, 3))
572636
refresh <- dplyr::if_else(verbosity_level >= 3, iter %/% 4, 0)
573637

@@ -592,7 +656,9 @@ NULL
592656
rlang::call2(rstanarm::stan_glm,
593657
formula = formula_full,
594658
family = fun_family, data = quote(data), QR = TRUE,
595-
refresh = refresh, chains = chains, iter = iter
659+
refresh = refresh, chains = chains, iter = iter,
660+
prior = prior, prior_intercept = prior_intercept,
661+
prior_aux = prior_aux
596662
)
597663
fit_ref <- eval(call_fit_ref)
598664

R/ermod-methods.R

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ plot.ermod_bin <- function(x, show_orig_data = FALSE, ...) {
5555
#' @export
5656
#' @rdname ermod_method
5757
coef.ermod <- function(object, ...) {
58-
if (inherits(object, "ermod_emax")) {
59-
stop("coef() not supported for ermod_emax object")
58+
if (!inherits(object, c("ermod_bin", "ermod_lin"))) {
59+
stop("coef() only supported for linear models")
6060
}
6161

6262
stats::coef(object$mod, ...)
@@ -65,8 +65,8 @@ coef.ermod <- function(object, ...) {
6565
#' @export
6666
#' @rdname ermod_method
6767
summary.ermod <- function(object, ...) {
68-
if (inherits(object, "ermod_emax")) {
69-
stop("summary() not supported for ermod_emax object")
68+
if (!inherits(object, c("ermod_bin", "ermod_lin"))) {
69+
stop("summary() only supported for linear models")
7070
}
7171

7272
summary(object$mod, ...)
@@ -223,7 +223,7 @@ extract_var_selected.ermod_cov_sel <- function(x) x$var_selected
223223
#' credible interval (.lower, .upper)
224224
#'
225225
extract_coef_exp_ci <- function(x, ci_width = 0.95) {
226-
# Check that input x is ermod object
226+
# Check that input x is linear ermod object
227227
if (!inherits(x, c("ermod_bin", "ermod_lin"))) {
228228
stop("extract_coef_exp_ci() only supported for linear models")
229229
}
@@ -367,3 +367,22 @@ as_draws_matrix.ermod <- function(x, ...) {
367367
as_draws_rvars.ermod <- function(x, ...) {
368368
posterior::as_draws_rvars(x$mod, ...)
369369
}
370+
371+
# prior_summary --------------------------------------------------------------
372+
#' Summarize the priors used for linear or linear logistic regression models
373+
#'
374+
#' See [rstanarm::prior_summary()] for details.
375+
#'
376+
#' @export
377+
#' @rdname prior_summary
378+
#' @importFrom rstanarm prior_summary
379+
#' @return An object of class `prior_summary.stanreg`
380+
#'
381+
prior_summary.ermod <- function(object, ...) {
382+
# Check that input x is linear ermod object
383+
if (!inherits(object, c("ermod_bin", "ermod_lin"))) {
384+
stop("prior_summary.ermod() only supported for linear models")
385+
}
386+
387+
rstanarm::prior_summary(object$mod, ...)
388+
}

R/p_direction.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,3 @@ p_direction.ermod_bin <- function(
8080
.if_run_ex_p_dir <- function() {
8181
requireNamespace("bayestestR", quietly = TRUE)
8282
}
83-

R/sim_ermod.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ sim_er <- function(
181181
#' )
182182
#'
183183
#' ersim_new_exp_med_qi
184-
#'}
184+
#' }
185185
#'
186186
sim_er_new_exp <- function(
187187
ermod,

man/dev_ermod_bin.Rd

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/dev_ermod_bin_cov_functions.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/dev_ermod_bin_cov_sel.Rd

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)