Skip to content

Commit 0e0c61b

Browse files
Add inclusion probability and (curated) group indicators to results$arguments for bgmCompare
1 parent 769e8c1 commit 0e0c61b

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

R/bgm.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@
239239
#' @param target_accept Numeric between 0 and 1. Target acceptance rate for
240240
#' the sampler. Defaults are set automatically if not supplied:
241241
#' \code{0.44} for adaptive Metropolis, \code{0.65} for HMC,
242-
#' and \code{0.60} for NUTS.
242+
#' and \code{0.80} for NUTS.
243243
#'
244244
#' @param hmc_num_leapfrogs Integer. Number of leapfrog steps for Hamiltonian
245245
#' Monte Carlo. Must be positive. Default: \code{100}.
@@ -418,7 +418,7 @@ bgm = function(
418418
} else if(update_method == "hamiltonian-mc") {
419419
target_accept = 0.65
420420
} else if(update_method == "nuts") {
421-
target_accept = 0.60
421+
target_accept = 0.80
422422
}
423423
}
424424

R/bgmCompare.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112
#' \code{"adaptive-metropolis"}, \code{"hamiltonian-mc"}, or \code{"nuts"}.
113113
#' Default: \code{"nuts"}.
114114
#' @param target_accept Numeric between 0 and 1. Target acceptance rate.
115-
#' Defaults: 0.44 (Metropolis), 0.65 (HMC), 0.60 (NUTS).
115+
#' Defaults: 0.44 (Metropolis), 0.65 (HMC), 0.80 (NUTS).
116116
#' @param hmc_num_leapfrogs Integer. Leapfrog steps for HMC. Default: \code{100}.
117117
#' @param nuts_max_depth Integer. Maximum tree depth for NUTS. Default: \code{10}.
118118
#' @param learn_mass_matrix Logical. If \code{TRUE}, adapt the mass matrix
@@ -314,7 +314,7 @@ bgmCompare = function(
314314
} else if(update_method == "hamiltonian-mc") {
315315
target_accept = 0.65
316316
} else if(update_method == "nuts") {
317-
target_accept = 0.60
317+
target_accept = 0.80
318318
}
319319
}
320320

@@ -524,6 +524,7 @@ bgmCompare = function(
524524
num_categories = num_categories,
525525
is_ordinal_variable = ordinal_variable,
526526
num_groups = num_groups,
527+
group = sorted_group,
527528
iter = iter,
528529
warmup = warmup,
529530
main_effect_indices = main_effect_indices,
@@ -533,6 +534,7 @@ bgmCompare = function(
533534
difference_prior = difference_prior,
534535
difference_selection_alpha = beta_bernoulli_alpha,
535536
difference_selection_beta = beta_bernoulli_beta,
537+
inclusion_probability = model$inclusion_probability_difference,
536538
pairwise_scale = pairwise_scale,
537539
difference_scale = difference_scale,
538540
update_method = update_method,
@@ -560,6 +562,7 @@ bgmCompare = function(
560562
num_categories = num_categories,
561563
is_ordinal_variable = ordinal_variable,
562564
num_groups = num_groups,
565+
group = sorted_group,
563566
iter = iter,
564567
warmup = warmup,
565568
main_effect_indices = main_effect_indices,
@@ -569,6 +572,7 @@ bgmCompare = function(
569572
difference_prior = difference_prior,
570573
difference_selection_alpha = beta_bernoulli_alpha,
571574
difference_selection_beta = beta_bernoulli_beta,
575+
inclusion_probability = model$inclusion_probability_difference,
572576
pairwise_scale = pairwise_scale,
573577
difference_scale = difference_scale,
574578
update_method = update_method,

R/output_utils.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,11 @@ generate_param_names_bgmCompare = function(
275275

276276
prepare_output_bgmCompare = function(
277277
out, observations, num_categories, is_ordinal_variable,
278-
num_groups, iter, warmup,
278+
num_groups, group, iter, warmup,
279279
main_effect_indices, pairwise_effect_indices,
280280
data_columnnames, difference_selection,
281281
difference_prior, difference_selection_alpha, difference_selection_beta,
282+
inclusion_probability,
282283
pairwise_scale, difference_scale,
283284
update_method, target_accept, nuts_max_depth, hmc_num_leapfrogs,
284285
learn_mass_matrix, num_chains, projection
@@ -297,6 +298,7 @@ prepare_output_bgmCompare = function(
297298
difference_prior = difference_prior,
298299
difference_selection_alpha = difference_selection_alpha,
299300
difference_selection_beta = difference_selection_beta,
301+
inclusion_probability = inclusion_probability,
300302
version = packageVersion("bgms"),
301303
update_method = update_method,
302304
target_accept = target_accept,
@@ -308,7 +310,8 @@ prepare_output_bgmCompare = function(
308310
data_columnnames = data_columnnames,
309311
projection = projection,
310312
num_categories = num_categories,
311-
is_ordinal_variable = is_ordinal_variable
313+
is_ordinal_variable = is_ordinal_variable,
314+
group = group
312315
)
313316

314317
# --- parameter names

0 commit comments

Comments
 (0)