Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 127 additions & 109 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -352,64 +352,65 @@
#'
#' @export
bgm = function(
x,
variable_type = "ordinal",
baseline_category,
iter = 1e3,
warmup = 1e3,
pairwise_scale = 2.5,
main_alpha = 0.5,
main_beta = 0.5,
edge_selection = TRUE,
edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"),
inclusion_probability = 0.5,
beta_bernoulli_alpha = 1,
beta_bernoulli_beta = 1,
beta_bernoulli_alpha_between = 1,
beta_bernoulli_beta_between = 1,
dirichlet_alpha = 1,
lambda = 1,
na_action = c("listwise", "impute"),
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
target_accept,
hmc_num_leapfrogs = 100,
nuts_max_depth = 10,
learn_mass_matrix = FALSE,
chains = 4,
cores = parallel::detectCores(),
display_progress = c("per-chain", "total", "none"),
seed = NULL,
interaction_scale,
burnin,
save,
threshold_alpha,
threshold_beta
x,
variable_type = "ordinal",
baseline_category,
iter = 1e3,
warmup = 1e3,
pairwise_scale = 2.5,
main_alpha = 0.5,
main_beta = 0.5,
edge_selection = TRUE,
edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"),
inclusion_probability = 0.5,
beta_bernoulli_alpha = 1,
beta_bernoulli_beta = 1,
beta_bernoulli_alpha_between = 1,
beta_bernoulli_beta_between = 1,
dirichlet_alpha = 1,
lambda = 1,
na_action = c("listwise", "impute"),
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
target_accept,
hmc_num_leapfrogs = 100,
nuts_max_depth = 10,
learn_mass_matrix = FALSE,
chains = 4,
cores = parallel::detectCores(),
display_progress = c("per-chain", "total", "none"),
seed = NULL,
interaction_scale,
burnin,
save,
threshold_alpha,
threshold_beta
) {
if (hasArg(interaction_scale)) {
if(hasArg(interaction_scale)) {
lifecycle::deprecate_warn("0.1.6.0", "bgm(interaction_scale =)", "bgm(pairwise_scale =)")
if (!hasArg(pairwise_scale)) {
if(!hasArg(pairwise_scale)) {
pairwise_scale = interaction_scale
}
}

if (hasArg(burnin)) {
if(hasArg(burnin)) {
lifecycle::deprecate_warn("0.1.6.0", "bgm(burnin =)", "bgm(warmup =)")
if (!hasArg(warmup)) {
if(!hasArg(warmup)) {
warmup = burnin
}
}

if (hasArg(save)) {
if(hasArg(save)) {
lifecycle::deprecate_warn("0.1.6.0", "bgm(save =)")
}

if (hasArg(threshold_alpha) || hasArg(threshold_beta)) {
lifecycle::deprecate_warn("0.1.6.0",
"bgm(threshold_alpha =, threshold_beta =)",
"bgm(main_alpha =, main_beta =)"
if(hasArg(threshold_alpha) || hasArg(threshold_beta)) {
lifecycle::deprecate_warn(
"0.1.6.0",
"bgm(threshold_alpha =, threshold_beta =)",
"bgm(main_alpha =, main_beta =)"
)
if (!hasArg(main_alpha)) main_alpha = threshold_alpha
if (!hasArg(main_beta)) main_beta = threshold_beta
if(!hasArg(main_alpha)) main_alpha = threshold_alpha
if(!hasArg(main_beta)) main_beta = threshold_beta
}

# Check update method
Expand All @@ -430,39 +431,45 @@ bgm = function(
}
}

#Check data input ------------------------------------------------------------
if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame"))
# Check data input ------------------------------------------------------------
if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame")) {
stop("The input x needs to be a matrix or dataframe.")
if(inherits(x, what = "data.frame"))
}
if(inherits(x, what = "data.frame")) {
x = data.matrix(x)
if(ncol(x) < 2)
}
if(ncol(x) < 2) {
stop("The matrix x should have more than one variable (columns).")
if(nrow(x) < 2)
}
if(nrow(x) < 2) {
stop("The matrix x should have more than one observation (rows).")
}

#Check model input -----------------------------------------------------------
model = check_model(x = x,
variable_type = variable_type,
baseline_category = baseline_category,
pairwise_scale = pairwise_scale,
main_alpha = main_alpha,
main_beta = main_beta,
edge_selection = edge_selection,
edge_prior = edge_prior,
inclusion_probability = inclusion_probability,
beta_bernoulli_alpha = beta_bernoulli_alpha,
beta_bernoulli_beta = beta_bernoulli_beta,
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
dirichlet_alpha = dirichlet_alpha,
lambda = lambda)
# Check model input -----------------------------------------------------------
model = check_model(
x = x,
variable_type = variable_type,
baseline_category = baseline_category,
pairwise_scale = pairwise_scale,
main_alpha = main_alpha,
main_beta = main_beta,
edge_selection = edge_selection,
edge_prior = edge_prior,
inclusion_probability = inclusion_probability,
beta_bernoulli_alpha = beta_bernoulli_alpha,
beta_bernoulli_beta = beta_bernoulli_beta,
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
dirichlet_alpha = dirichlet_alpha,
lambda = lambda
)

# check hyperparameters input
# If user left them NULL, pass -1 to C++ (means: ignore between prior)
if (is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) {
if(is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) {
beta_bernoulli_alpha_between <- -1.0
beta_bernoulli_beta_between <- -1.0
} else if (is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
beta_bernoulli_beta_between <- -1.0
} else if(is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
stop("If you wish to specify different between and within cluster probabilites,
provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between,
otherwise leave both NULL.")
Expand All @@ -479,11 +486,12 @@ bgm = function(
edge_prior = model$edge_prior
inclusion_probability = model$inclusion_probability

#Check Gibbs input -----------------------------------------------------------
# Check Gibbs input -----------------------------------------------------------
check_positive_integer(iter, "iter")
check_non_negative_integer(warmup, "warmup")
if(warmup < 1e3)
if(warmup < 1e3) {
warning("The warmup parameter is set to a low value. This may lead to unreliable results. Reset to a minimum of 1000 iterations.")
}
warmup = max(warmup, 1e3) # Set minimum warmup to 1000 iterations

check_positive_integer(hmc_num_leapfrogs, "hmc_num_leapfrogs")
Expand All @@ -492,22 +500,27 @@ bgm = function(
check_positive_integer(nuts_max_depth, "nuts_max_depth")
nuts_max_depth = max(nuts_max_depth, 1) # Set minimum nuts_max_depth to 1

#Check na_action -------------------------------------------------------------
# Check na_action -------------------------------------------------------------
na_action_input = na_action
na_action = try(match.arg(na_action), silent = TRUE)
if(inherits(na_action, what = "try-error"))
stop(paste0("The na_action argument should equal listwise or impute, not ",
na_action_input,
"."))
if(inherits(na_action, what = "try-error")) {
stop(paste0(
"The na_action argument should equal listwise or impute, not ",
na_action_input,
"."
))
}

#Check display_progress ------------------------------------------------------
# Check display_progress ------------------------------------------------------
progress_type = progress_type_from_display_progress(display_progress)

#Format the data input -------------------------------------------------------
data = reformat_data(x = x,
na_action = na_action,
variable_bool = variable_bool,
baseline_category = baseline_category)
# Format the data input -------------------------------------------------------
data = reformat_data(
x = x,
na_action = na_action,
variable_bool = variable_bool,
baseline_category = baseline_category
)
x = data$x
num_categories = data$num_categories
missing_index = data$missing_index
Expand All @@ -520,44 +533,47 @@ bgm = function(

# Starting value of model matrix ---------------------------------------------
indicator = matrix(1,
nrow = num_variables,
ncol = num_variables)
nrow = num_variables,
ncol = num_variables
)


#Starting values of interactions and thresholds (posterior mode) -------------
# Starting values of interactions and thresholds (posterior mode) -------------
interactions = matrix(0, nrow = num_variables, ncol = num_variables)
thresholds = matrix(0, nrow = num_variables, ncol = max(num_categories))

#Precompute the number of observations per category for each variable --------
# Precompute the number of observations per category for each variable --------
counts_per_category = matrix(0,
nrow = max(num_categories) + 1,
ncol = num_variables)
nrow = max(num_categories) + 1,
ncol = num_variables
)
for(variable in 1:num_variables) {
for(category in 0:num_categories[variable]) {
counts_per_category[category + 1, variable] = sum(x[, variable] == category)
}
}

#Precompute the sufficient statistics for the two Blume-Capel parameters -----
# Precompute the sufficient statistics for the two Blume-Capel parameters -----
blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables)
if(any(!variable_bool)) {
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
bc_vars = which(!variable_bool)
for(i in bc_vars) {
blume_capel_stats[1, i] = sum(x[, i])
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
}
}
pairwise_stats = t(x) %*% x

# Index matrix used in the c++ functions ------------------------------------
interaction_index_matrix = matrix(0,
nrow = num_variables * (num_variables - 1) / 2,
ncol = 3)
nrow = num_variables * (num_variables - 1) / 2,
ncol = 3
)
cntr = 0
for(variable1 in 1:(num_variables - 1)) {
for(variable2 in (variable1 + 1):num_variables) {
cntr = cntr + 1
cntr = cntr + 1
interaction_index_matrix[cntr, 1] = cntr - 1
interaction_index_matrix[cntr, 2] = variable1 - 1
interaction_index_matrix[cntr, 3] = variable2 - 1
Expand All @@ -566,21 +582,21 @@ bgm = function(

pairwise_effect_indices = matrix(NA, nrow = num_variables, ncol = num_variables)
tel = 0
for (v1 in seq_len(num_variables - 1)) {
for (v2 in seq((v1 + 1), num_variables)) {
for(v1 in seq_len(num_variables - 1)) {
for(v2 in seq((v1 + 1), num_variables)) {
pairwise_effect_indices[v1, v2] = tel
pairwise_effect_indices[v2, v1] = tel
tel = tel + 1 # C++ starts at zero
tel = tel + 1 # C++ starts at zero
}
}

#Setting the seed
if (missing(seed) || is.null(seed)) {
# Setting the seed
if(missing(seed) || is.null(seed)) {
# Draw a random seed if none provided
seed = sample.int(.Machine$integer.max, 1)
}

if (!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
stop("Argument 'seed' must be a single non-negative integer.")
}

Expand Down Expand Up @@ -612,13 +628,13 @@ bgm = function(


userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
if (userInterrupt) {
if(userInterrupt) {
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
# Try to prepare output, but catch any errors
output <- tryCatch(
prepare_output_bgm(
out = out, x = x, num_categories = num_categories, iter = iter,
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
is_ordinal_variable = variable_bool,
warmup = warmup, pairwise_scale = pairwise_scale,
main_alpha = main_alpha, main_beta = main_beta,
Expand Down Expand Up @@ -647,9 +663,9 @@ bgm = function(
}

# Main output handler in the wrapper function
output = prepare_output_bgm (
output = prepare_output_bgm(
out = out, x = x, num_categories = num_categories, iter = iter,
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
is_ordinal_variable = variable_bool,
warmup = warmup, pairwise_scale = pairwise_scale,
main_alpha = main_alpha, main_beta = main_beta,
Expand All @@ -669,7 +685,7 @@ bgm = function(
num_chains = chains
)

if (update_method == "nuts") {
if(update_method == "nuts") {
nuts_diag = summarize_nuts_diagnostics(out, nuts_max_depth = nuts_max_depth)
output$nuts_diag = nuts_diag
}
Expand All @@ -678,21 +694,23 @@ bgm = function(
# TODO: REMOVE after easybgm >= 0.2.2 is on CRAN
# Compatibility shim for easybgm <= 0.2.1
# -------------------------------------------------------------------
if ("easybgm" %in% loadedNamespaces()) {
if("easybgm" %in% loadedNamespaces()) {
ebgm_version <- utils::packageVersion("easybgm")
if (ebgm_version <= "0.2.1") {
warning("bgms is running in compatibility mode for easybgm (<= 0.2.1). ",
"This will be removed once easybgm >= 0.2.2 is on CRAN.")
if(ebgm_version <= "0.2.1") {
warning(
"bgms is running in compatibility mode for easybgm (<= 0.2.1). ",
"This will be removed once easybgm >= 0.2.2 is on CRAN."
)

# Add legacy variables to output
output$arguments$save <- TRUE
if (edge_selection) {
if(edge_selection) {
output$indicator <- extract_indicators(output)
}
output$interactions <- extract_pairwise_interactions(output)
output$thresholds <- extract_category_thresholds(output)
output$thresholds <- extract_category_thresholds(output)
}
}

return(output)
}
}
Loading
Loading