diff --git a/R/bgm.R b/R/bgm.R index dd4268d..cd42154 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -352,64 +352,65 @@ #' #' @export bgm = function( - x, - variable_type = "ordinal", - baseline_category, - iter = 1e3, - warmup = 1e3, - pairwise_scale = 2.5, - main_alpha = 0.5, - main_beta = 0.5, - edge_selection = TRUE, - edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"), - inclusion_probability = 0.5, - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - beta_bernoulli_alpha_between = 1, - beta_bernoulli_beta_between = 1, - dirichlet_alpha = 1, - lambda = 1, - na_action = c("listwise", "impute"), - update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), - target_accept, - hmc_num_leapfrogs = 100, - nuts_max_depth = 10, - learn_mass_matrix = FALSE, - chains = 4, - cores = parallel::detectCores(), - display_progress = c("per-chain", "total", "none"), - seed = NULL, - interaction_scale, - burnin, - save, - threshold_alpha, - threshold_beta + x, + variable_type = "ordinal", + baseline_category, + iter = 1e3, + warmup = 1e3, + pairwise_scale = 2.5, + main_alpha = 0.5, + main_beta = 0.5, + edge_selection = TRUE, + edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"), + inclusion_probability = 0.5, + beta_bernoulli_alpha = 1, + beta_bernoulli_beta = 1, + beta_bernoulli_alpha_between = 1, + beta_bernoulli_beta_between = 1, + dirichlet_alpha = 1, + lambda = 1, + na_action = c("listwise", "impute"), + update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), + target_accept, + hmc_num_leapfrogs = 100, + nuts_max_depth = 10, + learn_mass_matrix = FALSE, + chains = 4, + cores = parallel::detectCores(), + display_progress = c("per-chain", "total", "none"), + seed = NULL, + interaction_scale, + burnin, + save, + threshold_alpha, + threshold_beta ) { - if (hasArg(interaction_scale)) { + if(hasArg(interaction_scale)) { lifecycle::deprecate_warn("0.1.6.0", "bgm(interaction_scale =)", "bgm(pairwise_scale =)") - if (!hasArg(pairwise_scale)) { + if(!hasArg(pairwise_scale)) { pairwise_scale = interaction_scale } } - if (hasArg(burnin)) { + if(hasArg(burnin)) { lifecycle::deprecate_warn("0.1.6.0", "bgm(burnin =)", "bgm(warmup =)") - if (!hasArg(warmup)) { + if(!hasArg(warmup)) { warmup = burnin } } - if (hasArg(save)) { + if(hasArg(save)) { lifecycle::deprecate_warn("0.1.6.0", "bgm(save =)") } - if (hasArg(threshold_alpha) || hasArg(threshold_beta)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgm(threshold_alpha =, threshold_beta =)", - "bgm(main_alpha =, main_beta =)" + if(hasArg(threshold_alpha) || hasArg(threshold_beta)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgm(threshold_alpha =, threshold_beta =)", + "bgm(main_alpha =, main_beta =)" ) - if (!hasArg(main_alpha)) main_alpha = threshold_alpha - if (!hasArg(main_beta)) main_beta = threshold_beta + if(!hasArg(main_alpha)) main_alpha = threshold_alpha + if(!hasArg(main_beta)) main_beta = threshold_beta } # Check update method @@ -430,39 +431,45 @@ bgm = function( } } - #Check data input ------------------------------------------------------------ - if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame")) + # Check data input ------------------------------------------------------------ + if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame")) { stop("The input x needs to be a matrix or dataframe.") - if(inherits(x, what = "data.frame")) + } + if(inherits(x, what = "data.frame")) { x = data.matrix(x) - if(ncol(x) < 2) + } + if(ncol(x) < 2) { stop("The matrix x should have more than one variable (columns).") - if(nrow(x) < 2) + } + if(nrow(x) < 2) { stop("The matrix x should have more than one observation (rows).") + } - #Check model input ----------------------------------------------------------- - model = check_model(x = x, - variable_type = variable_type, - baseline_category = baseline_category, - pairwise_scale = pairwise_scale, - main_alpha = main_alpha, - main_beta = main_beta, - edge_selection = edge_selection, - edge_prior = edge_prior, - inclusion_probability = inclusion_probability, - beta_bernoulli_alpha = beta_bernoulli_alpha, - beta_bernoulli_beta = beta_bernoulli_beta, - beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, - beta_bernoulli_beta_between = beta_bernoulli_beta_between, - dirichlet_alpha = dirichlet_alpha, - lambda = lambda) + # Check model input ----------------------------------------------------------- + model = check_model( + x = x, + variable_type = variable_type, + baseline_category = baseline_category, + pairwise_scale = pairwise_scale, + main_alpha = main_alpha, + main_beta = main_beta, + edge_selection = edge_selection, + edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda + ) # check hyperparameters input # If user left them NULL, pass -1 to C++ (means: ignore between prior) - if (is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) { + if(is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) { beta_bernoulli_alpha_between <- -1.0 - beta_bernoulli_beta_between <- -1.0 - } else if (is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { + beta_bernoulli_beta_between <- -1.0 + } else if(is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { stop("If you wish to specify different between and within cluster probabilites, provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between, otherwise leave both NULL.") @@ -479,11 +486,12 @@ bgm = function( edge_prior = model$edge_prior inclusion_probability = model$inclusion_probability - #Check Gibbs input ----------------------------------------------------------- + # Check Gibbs input ----------------------------------------------------------- check_positive_integer(iter, "iter") check_non_negative_integer(warmup, "warmup") - if(warmup < 1e3) + if(warmup < 1e3) { warning("The warmup parameter is set to a low value. This may lead to unreliable results. Reset to a minimum of 1000 iterations.") + } warmup = max(warmup, 1e3) # Set minimum warmup to 1000 iterations check_positive_integer(hmc_num_leapfrogs, "hmc_num_leapfrogs") @@ -492,22 +500,27 @@ bgm = function( check_positive_integer(nuts_max_depth, "nuts_max_depth") nuts_max_depth = max(nuts_max_depth, 1) # Set minimum nuts_max_depth to 1 - #Check na_action ------------------------------------------------------------- + # Check na_action ------------------------------------------------------------- na_action_input = na_action na_action = try(match.arg(na_action), silent = TRUE) - if(inherits(na_action, what = "try-error")) - stop(paste0("The na_action argument should equal listwise or impute, not ", - na_action_input, - ".")) + if(inherits(na_action, what = "try-error")) { + stop(paste0( + "The na_action argument should equal listwise or impute, not ", + na_action_input, + "." + )) + } - #Check display_progress ------------------------------------------------------ + # Check display_progress ------------------------------------------------------ progress_type = progress_type_from_display_progress(display_progress) - #Format the data input ------------------------------------------------------- - data = reformat_data(x = x, - na_action = na_action, - variable_bool = variable_bool, - baseline_category = baseline_category) + # Format the data input ------------------------------------------------------- + data = reformat_data( + x = x, + na_action = na_action, + variable_bool = variable_bool, + baseline_category = baseline_category + ) x = data$x num_categories = data$num_categories missing_index = data$missing_index @@ -520,44 +533,47 @@ bgm = function( # Starting value of model matrix --------------------------------------------- indicator = matrix(1, - nrow = num_variables, - ncol = num_variables) + nrow = num_variables, + ncol = num_variables + ) - #Starting values of interactions and thresholds (posterior mode) ------------- + # Starting values of interactions and thresholds (posterior mode) ------------- interactions = matrix(0, nrow = num_variables, ncol = num_variables) thresholds = matrix(0, nrow = num_variables, ncol = max(num_categories)) - #Precompute the number of observations per category for each variable -------- + # Precompute the number of observations per category for each variable -------- counts_per_category = matrix(0, - nrow = max(num_categories) + 1, - ncol = num_variables) + nrow = max(num_categories) + 1, + ncol = num_variables + ) for(variable in 1:num_variables) { for(category in 0:num_categories[variable]) { counts_per_category[category + 1, variable] = sum(x[, variable] == category) } } - #Precompute the sufficient statistics for the two Blume-Capel parameters ----- + # Precompute the sufficient statistics for the two Blume-Capel parameters ----- blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables) if(any(!variable_bool)) { # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) bc_vars = which(!variable_bool) for(i in bc_vars) { blume_capel_stats[1, i] = sum(x[, i]) - blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2) + blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2) } } pairwise_stats = t(x) %*% x # Index matrix used in the c++ functions ------------------------------------ interaction_index_matrix = matrix(0, - nrow = num_variables * (num_variables - 1) / 2, - ncol = 3) + nrow = num_variables * (num_variables - 1) / 2, + ncol = 3 + ) cntr = 0 for(variable1 in 1:(num_variables - 1)) { for(variable2 in (variable1 + 1):num_variables) { - cntr = cntr + 1 + cntr = cntr + 1 interaction_index_matrix[cntr, 1] = cntr - 1 interaction_index_matrix[cntr, 2] = variable1 - 1 interaction_index_matrix[cntr, 3] = variable2 - 1 @@ -566,21 +582,21 @@ bgm = function( pairwise_effect_indices = matrix(NA, nrow = num_variables, ncol = num_variables) tel = 0 - for (v1 in seq_len(num_variables - 1)) { - for (v2 in seq((v1 + 1), num_variables)) { + for(v1 in seq_len(num_variables - 1)) { + for(v2 in seq((v1 + 1), num_variables)) { pairwise_effect_indices[v1, v2] = tel pairwise_effect_indices[v2, v1] = tel - tel = tel + 1 # C++ starts at zero + tel = tel + 1 # C++ starts at zero } } - #Setting the seed - if (missing(seed) || is.null(seed)) { + # Setting the seed + if(missing(seed) || is.null(seed)) { # Draw a random seed if none provided seed = sample.int(.Machine$integer.max, 1) } - if (!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { + if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { stop("Argument 'seed' must be a single non-negative integer.") } @@ -612,13 +628,13 @@ bgm = function( userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) - if (userInterrupt) { + if(userInterrupt) { warning("Stopped sampling after user interrupt, results are likely uninterpretable.") # Try to prepare output, but catch any errors output <- tryCatch( prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, main_alpha = main_alpha, main_beta = main_beta, @@ -647,9 +663,9 @@ bgm = function( } # Main output handler in the wrapper function - output = prepare_output_bgm ( + output = prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, main_alpha = main_alpha, main_beta = main_beta, @@ -669,7 +685,7 @@ bgm = function( num_chains = chains ) - if (update_method == "nuts") { + if(update_method == "nuts") { nuts_diag = summarize_nuts_diagnostics(out, nuts_max_depth = nuts_max_depth) output$nuts_diag = nuts_diag } @@ -678,21 +694,23 @@ bgm = function( # TODO: REMOVE after easybgm >= 0.2.2 is on CRAN # Compatibility shim for easybgm <= 0.2.1 # ------------------------------------------------------------------- - if ("easybgm" %in% loadedNamespaces()) { + if("easybgm" %in% loadedNamespaces()) { ebgm_version <- utils::packageVersion("easybgm") - if (ebgm_version <= "0.2.1") { - warning("bgms is running in compatibility mode for easybgm (<= 0.2.1). ", - "This will be removed once easybgm >= 0.2.2 is on CRAN.") + if(ebgm_version <= "0.2.1") { + warning( + "bgms is running in compatibility mode for easybgm (<= 0.2.1). ", + "This will be removed once easybgm >= 0.2.2 is on CRAN." + ) # Add legacy variables to output output$arguments$save <- TRUE - if (edge_selection) { + if(edge_selection) { output$indicator <- extract_indicators(output) } output$interactions <- extract_pairwise_interactions(output) - output$thresholds <- extract_category_thresholds(output) + output$thresholds <- extract_category_thresholds(output) } } return(output) -} \ No newline at end of file +} diff --git a/R/bgmCompare.R b/R/bgmCompare.R index 8b49724..6478ff9 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -176,127 +176,134 @@ #' #' @export bgmCompare = function( - x, - y, - group_indicator, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category, - difference_scale = 1, - difference_prior = c("Bernoulli", "Beta-Bernoulli"), - difference_probability = 0.5, - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - pairwise_scale = 2.5, - main_alpha = 0.5, - main_beta = 0.5, - iter = 1e3, - warmup = 1e3, - na_action = c("listwise", "impute"), - update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), - target_accept, - hmc_num_leapfrogs = 100, - nuts_max_depth = 10, - learn_mass_matrix = FALSE, - chains = 4, - cores = parallel::detectCores(), - display_progress = c("per-chain", "total", "none"), - seed = NULL, - main_difference_model, - reference_category, - main_difference_scale, - pairwise_difference_scale, - pairwise_difference_prior, - main_difference_prior, - pairwise_difference_probability, - main_difference_probability, - pairwise_beta_bernoulli_alpha, - pairwise_beta_bernoulli_beta, - main_beta_bernoulli_alpha, - main_beta_bernoulli_beta, - interaction_scale, - threshold_alpha, - threshold_beta, - burnin, - save + x, + y, + group_indicator, + difference_selection = TRUE, + variable_type = "ordinal", + baseline_category, + difference_scale = 1, + difference_prior = c("Bernoulli", "Beta-Bernoulli"), + difference_probability = 0.5, + beta_bernoulli_alpha = 1, + beta_bernoulli_beta = 1, + pairwise_scale = 2.5, + main_alpha = 0.5, + main_beta = 0.5, + iter = 1e3, + warmup = 1e3, + na_action = c("listwise", "impute"), + update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), + target_accept, + hmc_num_leapfrogs = 100, + nuts_max_depth = 10, + learn_mass_matrix = FALSE, + chains = 4, + cores = parallel::detectCores(), + display_progress = c("per-chain", "total", "none"), + seed = NULL, + main_difference_model, + reference_category, + main_difference_scale, + pairwise_difference_scale, + pairwise_difference_prior, + main_difference_prior, + pairwise_difference_probability, + main_difference_probability, + pairwise_beta_bernoulli_alpha, + pairwise_beta_bernoulli_beta, + main_beta_bernoulli_alpha, + main_beta_bernoulli_beta, + interaction_scale, + threshold_alpha, + threshold_beta, + burnin, + save ) { - if (hasArg(main_difference_model)) { + if(hasArg(main_difference_model)) { lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(main_difference_model =)") } - if (hasArg(reference_category)) { + if(hasArg(reference_category)) { lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(reference_category =)", "bgmCompare(baseline_category =)") - if (!hasArg(baseline_category)) baseline_category = reference_category + if(!hasArg(baseline_category)) baseline_category = reference_category } - if (hasArg(pairwise_difference_scale) || hasArg(main_difference_scale)) { - lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(pairwise_difference_scale =, main_difference_scale =)", - "bgmCompare(difference_scale =)") - if (!hasArg(difference_scale)) { - difference_scale = if (!missing(pairwise_difference_scale)) pairwise_difference_scale else main_difference_scale + if(hasArg(pairwise_difference_scale) || hasArg(main_difference_scale)) { + lifecycle::deprecate_warn( + "0.1.6.0", "bgmCompare(pairwise_difference_scale =, main_difference_scale =)", + "bgmCompare(difference_scale =)" + ) + if(!hasArg(difference_scale)) { + difference_scale = if(!missing(pairwise_difference_scale)) pairwise_difference_scale else main_difference_scale } } - if (hasArg(pairwise_difference_prior) || hasArg(main_difference_prior)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgmCompare(pairwise_difference_prior =, main_difference_prior =)", - "bgmCompare(difference_prior =)" + if(hasArg(pairwise_difference_prior) || hasArg(main_difference_prior)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgmCompare(pairwise_difference_prior =, main_difference_prior =)", + "bgmCompare(difference_prior =)" ) - if (!hasArg(difference_prior)) { - difference_prior = if (!missing(pairwise_difference_prior)) pairwise_difference_prior else main_difference_prior + if(!hasArg(difference_prior)) { + difference_prior = if(!missing(pairwise_difference_prior)) pairwise_difference_prior else main_difference_prior } } - if (hasArg(pairwise_difference_probability) || hasArg(main_difference_probability)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgmCompare(pairwise_difference_probability =, main_difference_probability =)", - "bgmCompare(difference_probability =)" + if(hasArg(pairwise_difference_probability) || hasArg(main_difference_probability)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgmCompare(pairwise_difference_probability =, main_difference_probability =)", + "bgmCompare(difference_probability =)" ) - if (!hasArg(difference_probability)) { - difference_probability = if (!missing(pairwise_difference_probability)) pairwise_difference_probability else main_difference_probability + if(!hasArg(difference_probability)) { + difference_probability = if(!missing(pairwise_difference_probability)) pairwise_difference_probability else main_difference_probability } } - if (hasArg(pairwise_beta_bernoulli_alpha) || hasArg(main_beta_bernoulli_alpha)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgmCompare(pairwise_beta_bernoulli_alpha =, main_beta_bernoulli_alpha =)", - "bgmCompare(beta_bernoulli_alpha =)" + if(hasArg(pairwise_beta_bernoulli_alpha) || hasArg(main_beta_bernoulli_alpha)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgmCompare(pairwise_beta_bernoulli_alpha =, main_beta_bernoulli_alpha =)", + "bgmCompare(beta_bernoulli_alpha =)" ) - if (!hasArg(beta_bernoulli_alpha)) { - beta_bernoulli_alpha = if (!missing(pairwise_beta_bernoulli_alpha)) pairwise_beta_bernoulli_alpha else main_beta_bernoulli_alpha + if(!hasArg(beta_bernoulli_alpha)) { + beta_bernoulli_alpha = if(!missing(pairwise_beta_bernoulli_alpha)) pairwise_beta_bernoulli_alpha else main_beta_bernoulli_alpha } } - if (hasArg(pairwise_beta_bernoulli_beta) || hasArg(main_beta_bernoulli_beta)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgmCompare(pairwise_beta_bernoulli_beta =, main_beta_bernoulli_beta =)", - "bgmCompare(beta_bernoulli_beta =)" + if(hasArg(pairwise_beta_bernoulli_beta) || hasArg(main_beta_bernoulli_beta)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgmCompare(pairwise_beta_bernoulli_beta =, main_beta_bernoulli_beta =)", + "bgmCompare(beta_bernoulli_beta =)" ) - if (!hasArg(beta_bernoulli_beta)) { - beta_bernoulli_beta = if (!missing(pairwise_beta_bernoulli_beta)) pairwise_beta_bernoulli_beta else main_beta_bernoulli_beta + if(!hasArg(beta_bernoulli_beta)) { + beta_bernoulli_beta = if(!missing(pairwise_beta_bernoulli_beta)) pairwise_beta_bernoulli_beta else main_beta_bernoulli_beta } } - if (hasArg(interaction_scale)) { + if(hasArg(interaction_scale)) { lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(interaction_scale =)", "bgmCompare(pairwise_scale =)") - if (!hasArg(pairwise_scale)) pairwise_scale = interaction_scale + if(!hasArg(pairwise_scale)) pairwise_scale = interaction_scale } - if (hasArg(threshold_alpha) || hasArg(threshold_beta)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgmCompare(threshold_alpha =, threshold_beta =)", - "bgmCompare(main_alpha =, main_beta =)" # = double-check if these are still part of bgmCompare + if(hasArg(threshold_alpha) || hasArg(threshold_beta)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgmCompare(threshold_alpha =, threshold_beta =)", + "bgmCompare(main_alpha =, main_beta =)" # = double-check if these are still part of bgmCompare ) - if (!hasArg(main_alpha)) main_alpha = threshold_alpha - if (!hasArg(main_beta)) main_beta = threshold_beta + if(!hasArg(main_alpha)) main_alpha = threshold_alpha + if(!hasArg(main_beta)) main_beta = threshold_beta } - if (hasArg(burnin)) { + if(hasArg(burnin)) { lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(burnin =)", "bgmCompare(warmup =)") - if (!hasArg(warmup)) warmup = burnin + if(!hasArg(warmup)) warmup = burnin } - if (hasArg(save)) { + if(hasArg(save)) { lifecycle::deprecate_warn("0.1.6.0", "bgmCompare(save =)") } @@ -320,27 +327,32 @@ bgmCompare = function( # Check and preprocess data x = data_check(x, "x") - if (hasArg(y)) { + if(hasArg(y)) { y = data_check(y, "y") - if (ncol(x) != ncol(y)) stop("x and y must have the same number of columns.") + if(ncol(x) != ncol(y)) stop("x and y must have the same number of columns.") } - if(!hasArg(y) & !hasArg(group_indicator)) - stop(paste0("For multi-group designs, the bgmCompare function requires input for\n", - "either y (group 2 data) or group_indicator (group indicator).")) + if(!hasArg(y) & !hasArg(group_indicator)) { + stop(paste0( + "For multi-group designs, the bgmCompare function requires input for\n", + "either y (group 2 data) or group_indicator (group indicator)." + )) + } # Validate group indicators - if (!hasArg(y) && hasArg(group_indicator)) { + if(!hasArg(y) && hasArg(group_indicator)) { group_indicator = as.vector(group_indicator) - if (anyNA(group_indicator)) stop("group_indicator cannot contain missing values.") - if (length(group_indicator) != nrow(x)) stop("Length of group_indicator must match number of rows in x.") + if(anyNA(group_indicator)) stop("group_indicator cannot contain missing values.") + if(length(group_indicator) != nrow(x)) stop("Length of group_indicator must match number of rows in x.") } # Model and preprocessing - if(!hasArg(y)) + if(!hasArg(y)) { y = NULL - if(!hasArg(group_indicator)) + } + if(!hasArg(group_indicator)) { group_indicator = NULL + } model = check_compare_model( x = x, y = y, group_indicator = group_indicator, difference_selection = difference_selection, @@ -366,7 +378,7 @@ bgmCompare = function( # Check na_action na_action_input = na_action na_action = try(match.arg(na_action), silent = TRUE) - if (inherits(na_action, "try-error")) { + if(inherits(na_action, "try-error")) { stop(sprintf("Invalid value for `na_action`. Expected 'listwise' or 'impute', got: %s", na_action_input)) } @@ -414,7 +426,7 @@ bgmCompare = function( counter = 0 for(variable1 in 1:(num_variables - 1)) { for(variable2 in (variable1 + 1):num_variables) { - counter = counter + 1 + counter = counter + 1 Index[counter, ] = c(counter, variable1 - 1, variable2 - 1) } } @@ -422,13 +434,13 @@ bgmCompare = function( # Gibbs sampling # Prepare indices for main and pairwise effects main_effect_indices = matrix(NA, nrow = num_variables, ncol = 2) - for (variable in seq_len(num_variables)) { - if (variable > 1) { + for(variable in seq_len(num_variables)) { + if(variable > 1) { main_effect_indices[variable, 1] = 1 + main_effect_indices[variable - 1, 2] } else { - main_effect_indices[variable, 1] = 0 # C++ starts at zero + main_effect_indices[variable, 1] = 0 # C++ starts at zero } - if (ordinal_variable[variable]) { + if(ordinal_variable[variable]) { main_effect_indices[variable, 2] = main_effect_indices[variable, 1] + num_categories[variable] - 1 } else { main_effect_indices[variable, 2] = main_effect_indices[variable, 1] + 1 @@ -437,11 +449,11 @@ bgmCompare = function( pairwise_effect_indices = matrix(NA, nrow = num_variables, ncol = num_variables) tel = 0 - for (v1 in seq_len(num_variables - 1)) { - for (v2 in seq((v1 + 1), num_variables)) { + for(v1 in seq_len(num_variables - 1)) { + for(v2 in seq((v1 + 1), num_variables)) { pairwise_effect_indices[v1, v2] = tel pairwise_effect_indices[v2, v1] = tel - tel = tel + 1 # C++ starts at zero + tel = tel + 1 # C++ starts at zero } } @@ -452,27 +464,27 @@ bgmCompare = function( # Align observations with sorted group observations = x sorted_group = sort(group) - for (g in unique(group)) { + for(g in unique(group)) { observations[which(sorted_group == g), ] = x[which(group == g), ] - group_indices[g, 1] = min(which(sorted_group == g)) - 1 # C++ starts at zero - group_indices[g, 2] = max(which(sorted_group == g)) - 1 # C++ starts at zero + group_indices[g, 1] = min(which(sorted_group == g)) - 1 # C++ starts at zero + group_indices[g, 2] = max(which(sorted_group == g)) - 1 # C++ starts at zero } # Compute projection matrix for group differences one = matrix(1, nrow = num_groups, ncol = num_groups) V = diag(num_groups) - one / num_groups projection = eigen(V)$vectors[, -num_groups] - if (num_groups == 2) { + if(num_groups == 2) { projection = matrix(projection, ncol = 1) / sqrt(2) } - #Setting the seed - if (missing(seed) || is.null(seed)) { + # Setting the seed + if(missing(seed) || is.null(seed)) { # Draw a random seed if none provided seed = sample.int(.Machine$integer.max, 1) } - if (!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { + if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { stop("Argument 'seed' must be a single non-negative integer.") } @@ -515,7 +527,7 @@ bgmCompare = function( ) userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) - if (userInterrupt) { + if(userInterrupt) { warning("Stopped sampling after user interrupt, results are likely uninterpretable.") output <- tryCatch( prepare_output_bgmCompare( @@ -529,7 +541,7 @@ bgmCompare = function( warmup = warmup, main_effect_indices = main_effect_indices, pairwise_effect_indices = pairwise_effect_indices, - data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), difference_selection = difference_selection, difference_prior = difference_prior, difference_selection_alpha = beta_bernoulli_alpha, @@ -567,7 +579,7 @@ bgmCompare = function( warmup = warmup, main_effect_indices = main_effect_indices, pairwise_effect_indices = pairwise_effect_indices, - data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), difference_selection = difference_selection, difference_prior = difference_prior, difference_selection_alpha = beta_bernoulli_alpha, @@ -583,10 +595,10 @@ bgmCompare = function( num_chains = chains, projection = projection ) - if (update_method == "nuts") { + if(update_method == "nuts") { nuts_diag = summarize_nuts_diagnostics(out, nuts_max_depth = nuts_max_depth) output$nuts_diag = nuts_diag } return(output) -} \ No newline at end of file +} diff --git a/R/bgmcompare-methods.r b/R/bgmcompare-methods.r index 9a665b2..b2abe9c 100644 --- a/R/bgmcompare-methods.r +++ b/R/bgmcompare-methods.r @@ -10,9 +10,8 @@ print.bgmCompare = function(x, ...) { arguments = extract_arguments(x) # Model type - if (isTRUE(arguments$difference_selection)) { - prior_msg = switch( - as.character(arguments$difference_prior), + if(isTRUE(arguments$difference_selection)) { + prior_msg = switch(as.character(arguments$difference_prior), "Bernoulli" = "Bayesian Difference Selection (Bernoulli prior on inclusion)", "Beta-Bernoulli" = "Bayesian Difference Selection (Beta-Bernoulli prior on inclusion)", "Bayesian Difference Selection" @@ -24,12 +23,12 @@ print.bgmCompare = function(x, ...) { # Dataset info cat(paste0(" Number of variables: ", arguments$num_variables, "\n")) - if (!is.null(arguments$num_groups)) { + if(!is.null(arguments$num_groups)) { cat(paste0(" Number of groups: ", arguments$num_groups, "\n")) } - if (!is.null(arguments$num_cases)) { + if(!is.null(arguments$num_cases)) { # In our prepare_output_bgmCompare() we stored total cases in num_cases. - if (isTRUE(arguments$na_impute)) { + if(isTRUE(arguments$na_impute)) { cat(paste0(" Number of cases: ", arguments$num_cases, " (missings imputed)\n")) } else { cat(paste0(" Number of cases: ", arguments$num_cases, "\n")) @@ -37,7 +36,7 @@ print.bgmCompare = function(x, ...) { } # Iterations and chains - if (!is.null(arguments$num_chains)) { + if(!is.null(arguments$num_chains)) { total_iter = arguments$iter * arguments$num_chains cat(paste0(" Number of post-burnin MCMC iterations: ", total_iter, "\n")) cat(paste0(" Number of MCMC chains: ", arguments$num_chains, "\n")) @@ -50,7 +49,6 @@ print.bgmCompare = function(x, ...) { } - #' @name summary.bgmCompare #' @title Summary method for `bgmCompare` objects #' @@ -64,23 +62,22 @@ print.bgmCompare = function(x, ...) { summary.bgmCompare = function(object, ...) { arguments = extract_arguments(object) - if (!is.null(object$posterior_summary_main_baseline) && - !is.null(object$posterior_summary_pairwise_baseline)) { - + if(!is.null(object$posterior_summary_main_baseline) && + !is.null(object$posterior_summary_pairwise_baseline)) { out = list( main = object$posterior_summary_main_baseline, pairwise = object$posterior_summary_pairwise_baseline ) - if (!is.null(object$posterior_summary_indicator)) { + if(!is.null(object$posterior_summary_indicator)) { out$indicator = object$posterior_summary_indicator } - if (!is.null(object$posterior_summary_main_differences)) { + if(!is.null(object$posterior_summary_main_differences)) { out$main_diff = object$posterior_summary_main_differences } - if (!is.null(object$posterior_summary_pairwise_differences)) { + if(!is.null(object$posterior_summary_pairwise_differences)) { out$pairwise_diff = object$posterior_summary_pairwise_differences } @@ -89,9 +86,11 @@ summary.bgmCompare = function(object, ...) { return(out) } - message("No summary statistics available for this model object.\n", - "Try fitting the model again using the latest bgms version,\n", - "or use the `easybgm` package for diagnostic summaries and plotting.") + message( + "No summary statistics available for this model object.\n", + "Try fitting the model again using the latest bgms version,\n", + "or use the `easybgm` package for diagnostic summaries and plotting." + ) invisible(NULL) } @@ -102,35 +101,37 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { print_df = function(df, digits) { df2 = df - if (ncol(df2) > 1) { - df2[ , -1] = lapply(df2[ , -1, drop = FALSE], round, digits = digits) + if(ncol(df2) > 1) { + df2[, -1] = lapply(df2[, -1, drop = FALSE], round, digits = digits) } print(head(df2, 6)) } - if (!is.null(x$main)) { + if(!is.null(x$main)) { cat("Category thresholds:\n") print_df(x$main, digits) - if (nrow(x$main) > 6) + if(nrow(x$main) > 6) { cat("... (use `summary(fit)$main` to see full output)\n") + } cat("\n") } - if (!is.null(x$pairwise)) { + if(!is.null(x$pairwise)) { cat("Pairwise interactions:\n") print_df(x$pairwise, digits) - if (nrow(x$pairwise) > 6) + if(nrow(x$pairwise) > 6) { cat("... (use `summary(fit)$pairwise` to see full output)\n") + } cat("\n") } - if (!is.null(x$indicator)) { + if(!is.null(x$indicator)) { cat("Inclusion probabilities:\n") ind <- head(x$indicator, 6) # round only numeric columns ind[] <- lapply(ind, function(col) { - if (is.numeric(col)) { + if(is.numeric(col)) { round(col, digits) } else { col @@ -143,29 +144,33 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { }) print(ind, row.names = FALSE) - if (nrow(x$indicator) > 6) + if(nrow(x$indicator) > 6) { cat("... (use `summary(fit)$indicator` to see full output)\n") + } cat("Note: NA values are suppressed in the print table. They occur when an indicator\n") cat("was constant (all 0 or all 1) across all iterations, so sd/mcse/n_eff/Rhat\n") cat("are undefined; `summary(fit)$indicator` still contains the NA values.\n\n") } - if (!is.null(x$main_diff)) { + if(!is.null(x$main_diff)) { cat("Group differences (main effects):\n") maind <- head(x$main_diff, 6) # Only round numeric columns is_num <- vapply(maind, is.numeric, logical(1)) - maind[is_num] <- lapply(maind[is_num], - function(col) ifelse(is.na(col), "", round(col, digits))) + maind[is_num] <- lapply( + maind[is_num], + function(col) ifelse(is.na(col), "", round(col, digits)) + ) print(maind, row.names = FALSE) - if (nrow(x$main_diff) > 6) + if(nrow(x$main_diff) > 6) { cat("... (use `summary(fit)$main_diff` to see full output)\n") + } - if (!is.null(x$indicator)) { + if(!is.null(x$indicator)) { cat("Note: NA values are suppressed in the print table. They occur here when an\n") cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n") cat("`summary(fit)$main_diff` still contains the NA values.\n") @@ -173,22 +178,25 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { cat("\n") } - if (!is.null(x$pairwise_diff)) { + if(!is.null(x$pairwise_diff)) { cat("Group differences (pairwise effects):\n") pairwised <- head(x$pairwise_diff, 6) # Only round numeric columns is_num <- vapply(pairwised, is.numeric, logical(1)) - pairwised[is_num] <- lapply(pairwised[is_num], - function(col) ifelse(is.na(col), "", round(col, digits))) + pairwised[is_num] <- lapply( + pairwised[is_num], + function(col) ifelse(is.na(col), "", round(col, digits)) + ) print(pairwised, row.names = FALSE) - if (nrow(x$pairwise_diff) > 6) + if(nrow(x$pairwise_diff) > 6) { cat("... (use `summary(fit)$pairwise_diff` to see full output)\n") + } - if (!is.null(x$indicator)) { + if(!is.null(x$indicator)) { cat("Note: NA values are suppressed in the print table. They occur here when an\n") cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n") cat("`summary(fit)$pairwise_diff` still contains the NA values.\n") @@ -201,7 +209,6 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { } - #' @title Extract Coefficients from a bgmCompare Object #' @name coef.bgmCompare #' @description Returns posterior means for raw parameters (baseline + differences) @@ -227,26 +234,28 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { coef.bgmCompare <- function(object, ...) { args <- extract_arguments(object) - var_names <- args$data_columnnames + var_names <- args$data_columnnames num_categories <- as.integer(args$num_categories) - is_ordinal <- as.logical(args$is_ordinal_variable) - num_groups <- as.integer(args$num_groups) - num_variables <- as.integer(args$num_variables) - projection <- args$projection # [num_groups x (num_groups-1)] + is_ordinal <- as.logical(args$is_ordinal_variable) + num_groups <- as.integer(args$num_groups) + num_variables <- as.integer(args$num_variables) + projection <- args$projection # [num_groups x (num_groups-1)] # ---- helper: combine chains into [iter, chain, param], robust to vectors/1-col to_array3d <- function(xlist) { - if (is.null(xlist)) return(NULL) + if(is.null(xlist)) { + return(NULL) + } stopifnot(length(xlist) >= 1) mats <- lapply(xlist, function(x) { m <- as.matrix(x) - if (is.null(dim(m))) m <- matrix(m, ncol = 1L) + if(is.null(dim(m))) m <- matrix(m, ncol = 1L) m }) - niter <- nrow(mats[[1]]) + niter <- nrow(mats[[1]]) nparam <- ncol(mats[[1]]) arr <- array(NA_real_, dim = c(niter, length(mats), nparam)) - for (c in seq_along(mats)) arr[, c, ] <- mats[[c]] + for(c in seq_along(mats)) arr[, c, ] <- mats[[c]] arr } @@ -263,20 +272,22 @@ coef.bgmCompare <- function(object, ...) { # row names in sampler row order rownames(main_mat) <- unlist(lapply(seq_len(num_variables), function(v) { - if (is_ordinal[v]) { + if(is_ordinal[v]) { paste0(var_names[v], "(c", seq_len(num_categories[v]), ")") } else { - c(paste0(var_names[v], "(linear)"), - paste0(var_names[v], "(quadratic)")) + c( + paste0(var_names[v], "(linear)"), + paste0(var_names[v], "(quadratic)") + ) } })) colnames(main_mat) <- c("baseline", paste0("diff", seq_len(num_groups - 1L))) # group-specific main effects: baseline + P %*% diffs main_effects_groups <- matrix(NA_real_, nrow = num_main, ncol = num_groups) - for (r in seq_len(num_main)) { + for(r in seq_len(num_main)) { baseline <- main_mat[r, 1] - diffs <- main_mat[r, -1, drop = TRUE] + diffs <- main_mat[r, -1, drop = TRUE] main_effects_groups[r, ] <- baseline + as.vector(projection %*% diffs) } rownames(main_effects_groups) <- rownames(main_mat) @@ -295,9 +306,9 @@ coef.bgmCompare <- function(object, ...) { # row names in sampler row order (upper-tri i= 2L) { - for (i in 1L:(num_variables - 1L)) { - for (j in (i + 1L):num_variables) { + if(num_variables >= 2L) { + for(i in 1L:(num_variables - 1L)) { + for(j in (i + 1L):num_variables) { pair_names <- c(pair_names, paste0(var_names[i], "-", var_names[j])) } } @@ -307,9 +318,9 @@ coef.bgmCompare <- function(object, ...) { # group-specific pairwise effects pairwise_effects_groups <- matrix(NA_real_, nrow = num_pair, ncol = num_groups) - for (r in seq_len(num_pair)) { + for(r in seq_len(num_pair)) { baseline <- pairwise_mat[r, 1] - diffs <- pairwise_mat[r, -1, drop = TRUE] + diffs <- pairwise_mat[r, -1, drop = TRUE] pairwise_effects_groups[r, ] <- baseline + as.vector(projection %*% diffs) } rownames(pairwise_effects_groups) <- rownames(pairwise_mat) @@ -319,7 +330,7 @@ coef.bgmCompare <- function(object, ...) { # ---- indicators (present only if selection was used) ---- indicators <- NULL array3d_ind <- to_array3d(object$raw_samples$indicator) - if (!is.null(array3d_ind)) { + if(!is.null(array3d_ind)) { mean_ind <- apply(array3d_ind, 3, mean) # reconstruct VxV matrix using the sampler’s interleaved order: @@ -327,15 +338,19 @@ coef.bgmCompare <- function(object, ...) { V <- num_variables stopifnot(length(mean_ind) == V * (V + 1L) / 2L) - ind_mat <- matrix(0, nrow = V, ncol = V, - dimnames = list(var_names, var_names)) + ind_mat <- matrix(0, + nrow = V, ncol = V, + dimnames = list(var_names, var_names) + ) pos <- 1L - for (i in seq_len(V)) { + for(i in seq_len(V)) { # diagonal (main indicator) - ind_mat[i, i] <- mean_ind[pos]; pos <- pos + 1L - if (i < V) { - for (j in (i + 1L):V) { - val <- mean_ind[pos]; pos <- pos + 1L + ind_mat[i, i] <- mean_ind[pos] + pos <- pos + 1L + if(i < V) { + for(j in (i + 1L):V) { + val <- mean_ind[pos] + pos <- pos + 1L ind_mat[i, j] <- val ind_mat[j, i] <- val } diff --git a/R/bgms-methods.R b/R/bgms-methods.R index 554f318..86db89d 100644 --- a/R/bgms-methods.R +++ b/R/bgms-methods.R @@ -11,12 +11,12 @@ print.bgms <- function(x, ...) { arguments <- extract_arguments(x) # Model type - if (isTRUE(arguments$edge_selection)) { + if(isTRUE(arguments$edge_selection)) { prior_msg <- switch(arguments$edge_prior, - "Bernoulli" = "Bayesian Edge Selection using a Bernoulli prior on edge inclusion", - "Beta-Bernoulli" = "Bayesian Edge Selection using a Beta-Bernoulli prior on edge inclusion", - "Stochastic-Block" = "Bayesian Edge Selection using a Stochastic Block prior on edge inclusion", - "Bayesian Edge Selection" + "Bernoulli" = "Bayesian Edge Selection using a Bernoulli prior on edge inclusion", + "Beta-Bernoulli" = "Bayesian Edge Selection using a Beta-Bernoulli prior on edge inclusion", + "Stochastic-Block" = "Bayesian Edge Selection using a Stochastic Block prior on edge inclusion", + "Bayesian Edge Selection" ) cat(prior_msg, "\n") } else { @@ -25,14 +25,14 @@ print.bgms <- function(x, ...) { # Dataset info cat(paste0(" Number of variables: ", arguments$num_variables, "\n")) - if (isTRUE(arguments$na_impute)) { + if(isTRUE(arguments$na_impute)) { cat(paste0(" Number of cases: ", arguments$num_cases, " (missings imputed)\n")) } else { cat(paste0(" Number of cases: ", arguments$num_cases, "\n")) } # Iterations and chains - if (!is.null(arguments$num_chains)) { + if(!is.null(arguments$num_chains)) { total_iter <- arguments$iter * arguments$num_chains cat(paste0(" Number of post-burnin MCMC iterations: ", total_iter, "\n")) cat(paste0(" Number of MCMC chains: ", arguments$num_chains, "\n")) @@ -45,7 +45,6 @@ print.bgms <- function(x, ...) { } - #' @name summary.bgms #' @title Summary method for `bgms` objects #' @@ -59,17 +58,17 @@ print.bgms <- function(x, ...) { summary.bgms <- function(object, ...) { arguments <- extract_arguments(object) - if (!is.null(object$posterior_summary_main) && !is.null(object$posterior_summary_pairwise)) { + if(!is.null(object$posterior_summary_main) && !is.null(object$posterior_summary_pairwise)) { out <- list( main = object$posterior_summary_main, pairwise = object$posterior_summary_pairwise ) - if (!is.null(object$posterior_summary_indicator)) { + if(!is.null(object$posterior_summary_indicator)) { out$indicator <- object$posterior_summary_indicator } - if (!is.null(object$posterior_summary_pairwise_allocations)) { + if(!is.null(object$posterior_summary_pairwise_allocations)) { out$allocations <- object$posterior_summary_pairwise_allocations out$mean_allocations <- object$posterior_mean_allocations out$mode_allocations <- object$posterior_mode_allocations @@ -80,33 +79,34 @@ summary.bgms <- function(object, ...) { return(out) } - message("No summary statistics available for this model object.\n", - "Try fitting the model again using the latest bgms version,\n", - "or use the `easybgm` package for diagnostic summaries and plotting.") + message( + "No summary statistics available for this model object.\n", + "Try fitting the model again using the latest bgms version,\n", + "or use the `easybgm` package for diagnostic summaries and plotting." + ) invisible(NULL) } - #' @export print.summary.bgms <- function(x, digits = 3, ...) { cat("Posterior summaries from Bayesian estimation:\n\n") - if (!is.null(x$main)) { + if(!is.null(x$main)) { cat("Category thresholds:\n") print(round(head(x$main, 6), digits = digits)) - if (nrow(x$main) > 6) cat("... (use `summary(fit)$main` to see full output)\n") + if(nrow(x$main) > 6) cat("... (use `summary(fit)$main` to see full output)\n") cat("\n") } - if (!is.null(x$pairwise)) { + if(!is.null(x$pairwise)) { cat("Pairwise interactions:\n") pair <- head(x$pairwise, 6) pair[] <- lapply(pair, function(col) ifelse(is.na(col), "", round(col, digits))) print(pair) - #print(round(head(x$pairwise, 6), digits = digits)) - if (nrow(x$pairwise) > 6) cat("... (use `summary(fit)$pairwise` to see full output)\n") - if (!is.null(x$indicator)) { + # print(round(head(x$pairwise, 6), digits = digits)) + if(nrow(x$pairwise) > 6) cat("... (use `summary(fit)$pairwise` to see full output)\n") + if(!is.null(x$indicator)) { cat("Note: NA values are suppressed in the print table. They occur here when an \n") cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n") cat("`summary(fit)$pairwise` still contains the NA values.\n") @@ -114,25 +114,25 @@ print.summary.bgms <- function(x, digits = 3, ...) { cat("\n") } - if (!is.null(x$indicator)) { + if(!is.null(x$indicator)) { cat("Inclusion probabilities:\n") ind <- head(x$indicator, 6) ind[] <- lapply(ind, function(col) ifelse(is.na(col), "", round(col, digits))) print(ind) - if (nrow(x$indicator) > 6) cat("... (use `summary(fit)$indicator` to see full output)\n") + if(nrow(x$indicator) > 6) cat("... (use `summary(fit)$indicator` to see full output)\n") cat("Note: NA values are suppressed in the print table. They occur when an indicator\n") cat("was constant (all 0 or all 1) across all iterations, so sd/mcse/n_eff/Rhat\n") cat("are undefined; `summary(fit)$indicator` still contains the NA values.\n\n") } - if (!is.null(x$allocations)) { + if(!is.null(x$allocations)) { cat("Pairwise node co-clustering proportion:\n") print(round(head(x$allocations, 6), digits = digits)) - if (nrow(x$allocations) > 6) cat("... (use `summary(fit)$allocations` to see full output)\n") + if(nrow(x$allocations) > 6) cat("... (use `summary(fit)$allocations` to see full output)\n") cat("\n") } - if (!is.null(x$mean_allocations)) { + if(!is.null(x$mean_allocations)) { cat("Mean posterior node allocation vector:\n") print(round(head(x$mean_allocations, 6), digits = digits)) cat("Mode posterior node allocation vector:\n") @@ -140,10 +140,10 @@ print.summary.bgms <- function(x, digits = 3, ...) { cat("\n") } - if (!is.null(x$num_blocks)) { + if(!is.null(x$num_blocks)) { cat("Number of blocks and their posterior probability :\n") print(round(head(x$num_blocks, 6), digits = digits)) - if (nrow(x$num_blocks) > 6) cat("... (use `summary(fit)$posterior_num_blocks` to see full output)\n") + if(nrow(x$num_blocks) > 6) cat("... (use `summary(fit)$posterior_num_blocks` to see full output)\n") cat("\n") } @@ -152,7 +152,6 @@ print.summary.bgms <- function(x, digits = 3, ...) { } - #' @title Extract Coefficients from a bgms Object #' @name coef.bgms #' @description Returns the posterior mean thresholds, pairwise effects, and edge inclusion indicators from a \code{bgms} model fit. @@ -173,11 +172,11 @@ coef.bgms <- function(object, ...) { main = object$posterior_mean_main, pairwise = object$posterior_mean_pairwise ) - if (!is.null(object$posterior_mean_indicator)) { + if(!is.null(object$posterior_mean_indicator)) { out$indicator <- object$posterior_mean_indicator } - if (!is.null(object$posterior_mean_allocations)) { + if(!is.null(object$posterior_mean_allocations)) { out$mean_allocations <- object$posterior_mean_allocations out$mode_allocations <- object$posterior_mode_allocations out$num_blocks <- object$posterior_num_blocks @@ -189,7 +188,7 @@ coef.bgms <- function(object, ...) { .warning_issued <- FALSE warning_once <- function(msg) { - if (!.warning_issued) { + if(!.warning_issued) { warning(msg, call. = FALSE) .warning_issued <<- TRUE } diff --git a/R/bgms-package.R b/R/bgms-package.R index ae70418..f9a5828 100644 --- a/R/bgms-package.R +++ b/R/bgms-package.R @@ -63,4 +63,3 @@ ## usethis namespace: start ## usethis namespace: end NULL - diff --git a/R/data_utils.R b/R/data_utils.R index 0c47729..6559360 100644 --- a/R/data_utils.R +++ b/R/data_utils.R @@ -4,29 +4,50 @@ reformat_data = function(x, baseline_category) { if(na_action == "listwise") { # Check for missing values --------------------------------------------------- - missing_values = sapply(1:nrow(x), function(row){anyNA(x[row, ])}) - if(sum(missing_values) == nrow(x)) - stop(paste0("All rows in x contain at least one missing response.\n", - "You could try option na_action = impute.")) - if(sum(missing_values) > 1) - warning(paste0("There were ", - sum(missing_values), - " rows with missing observations in the input matrix x.\n", - "Since na_action = listwise these rows were excluded from the analysis."), - call. = FALSE) - if(sum(missing_values) == 1) - warning(paste0("There was one row with missing observations in the input matrix x.\n", - "Since na_action = listwise this row was excluded from \n", - "the analysis."), - call. = FALSE) + missing_values = sapply(1:nrow(x), function(row) { + anyNA(x[row, ]) + }) + if(sum(missing_values) == nrow(x)) { + stop(paste0( + "All rows in x contain at least one missing response.\n", + "You could try option na_action = impute." + )) + } + if(sum(missing_values) > 1) { + warning( + paste0( + "There were ", + sum(missing_values), + " rows with missing observations in the input matrix x.\n", + "Since na_action = listwise these rows were excluded from the analysis." + ), + call. = FALSE + ) + } + if(sum(missing_values) == 1) { + warning( + paste0( + "There was one row with missing observations in the input matrix x.\n", + "Since na_action = listwise this row was excluded from \n", + "the analysis." + ), + call. = FALSE + ) + } x = x[!missing_values, ] - if(ncol(x) < 2 || is.null(ncol(x))) - stop(paste0("After removing missing observations from the input matrix x,\n", - "there were less than two columns left in x.")) - if(nrow(x) < 2 || is.null(nrow(x))) - stop(paste0("After removing missing observations from the input matrix x,\n", - "there were less than two rows left in x.")) + if(ncol(x) < 2 || is.null(ncol(x))) { + stop(paste0( + "After removing missing observations from the input matrix x,\n", + "there were less than two columns left in x." + )) + } + if(nrow(x) < 2 || is.null(nrow(x))) { + stop(paste0( + "After removing missing observations from the input matrix x,\n", + "there were less than two rows left in x." + )) + } missing_index = matrix(NA, nrow = 1, ncol = 1) na_impute = FALSE @@ -44,11 +65,12 @@ reformat_data = function(x, if(length(mis) > 0) { for(i in 1:length(mis)) { cntr = cntr + 1 - missing_index[cntr, 1] = mis[i] - 1 #c++ index starts at 0 - missing_index[cntr, 2] = node - 1 #c++ index starts at 0 - x[mis[i], node] = sample(x[-mis, node], #start value for imputation - size = 1) - #This is non-zero if no zeroes are observed (we then collapse over zero below) + missing_index[cntr, 1] = mis[i] - 1 # c++ index starts at 0 + missing_index[cntr, 2] = node - 1 # c++ index starts at 0 + x[mis[i], node] = sample(x[-mis, node], # start value for imputation + size = 1 + ) + # This is non-zero if no zeroes are observed (we then collapse over zero below) } } } @@ -62,17 +84,20 @@ reformat_data = function(x, num_variables = ncol(x) num_categories = vector(length = num_variables) for(node in 1:num_variables) { - unq_vls = sort(unique(x[, node])) + unq_vls = sort(unique(x[, node])) mx_vl = max(unq_vls) # Check if observed responses are not all unique --------------------------- - if(mx_vl == nrow(x)) - stop(paste0("Only unique responses observed for variable ", - node, - ". We expect >= 1 observations per category.")) + if(mx_vl == nrow(x)) { + stop(paste0( + "Only unique responses observed for variable ", + node, + ". We expect >= 1 observations per category." + )) + } # Recode data -------------------------------------------------------------- - if(variable_bool[node]) {#Regular ordinal variable + if(variable_bool[node]) { # Regular ordinal variable # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) if(length(unq_vls) != mx_vl + 1 || any(unq_vls != 0:mx_vl)) { y = x[, node] @@ -82,30 +107,35 @@ reformat_data = function(x, cntr = cntr + 1 } } - } else {#Blume-Capel ordinal variable + } else { # Blume-Capel ordinal variable # Check if observations are integer or can be recoded -------------------- - if (any(abs(unq_vls - round(unq_vls)) > .Machine$double.eps)) { + if(any(abs(unq_vls - round(unq_vls)) > .Machine$double.eps)) { int_unq_vls = unique(as.integer(unq_vls)) if(anyNA(int_unq_vls)) { stop(paste0( "The Blume-Capel model assumes that its observations are coded as integers, but \n", "the category scores for node ", node, " were not integer. An attempt to recode \n", "them to integer failed. Please inspect the documentation for the base R \n", - "function as.integer(), which bgm uses for recoding category scores.")) + "function as.integer(), which bgm uses for recoding category scores." + )) } if(length(int_unq_vls) != length(unq_vls)) { - stop(paste0("The Blume-Capel model assumes that its observations are coded as integers. The \n", - "category scores of the observations for node ", node, " were not integers. An \n", - "attempt to recode these observations as integers failed because, after rounding, \n", - "a single integer value was used for several observed score categories.")) + stop(paste0( + "The Blume-Capel model assumes that its observations are coded as integers. The \n", + "category scores of the observations for node ", node, " were not integers. An \n", + "attempt to recode these observations as integers failed because, after rounding, \n", + "a single integer value was used for several observed score categories." + )) } x[, node] = as.integer(x[, node]) - if(baseline_category[node] < 0 | baseline_category[node] > max(x[, node])) + if(baseline_category[node] < 0 | baseline_category[node] > max(x[, node])) { stop(paste0( "The reference category for the Blume-Capel variable ", node, "is outside its \n", - "range of observations.")) + "range of observations." + )) + } } # Check if observations start at zero and recode otherwise --------------- @@ -122,68 +152,82 @@ reformat_data = function(x, } check_range = length(unique(x[, node])) - if(check_range < 3) - stop(paste0("The Blume-Capel is only available for variables with more than one category \n", - "observed. There two or less categories observed for variable ", - node, - ".")) + if(check_range < 3) { + stop(paste0( + "The Blume-Capel is only available for variables with more than one category \n", + "observed. There two or less categories observed for variable ", + node, + "." + )) + } } # Warn that maximum category value is large -------------------------------- - num_categories[node] = max(x[,node]) + num_categories[node] = max(x[, node]) if(!variable_bool[node] & num_categories[node] > 10) { # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) - warning(paste0("In the (pseudo) likelihood of Blume-Capel variables, the normalization constant \n", - "is a sum over all possible values of the ordinal variable. The range of \n", - "observed values, possibly after recoding to integers, is assumed to be the \n", - "number of possible response categories. For node ", node,", this range was \n", - "equal to ", num_categories[node], "which may cause the analysis to take some \n", - "time to run. Note that for the Blume-Capel model, the bgm function does not \n", - "collapse the categories that have no observations between zero and the last \n", - "category. This may explain the large discrepancy between the first and last \n", - "category values.")) + warning(paste0( + "In the (pseudo) likelihood of Blume-Capel variables, the normalization constant \n", + "is a sum over all possible values of the ordinal variable. The range of \n", + "observed values, possibly after recoding to integers, is assumed to be the \n", + "number of possible response categories. For node ", node, ", this range was \n", + "equal to ", num_categories[node], "which may cause the analysis to take some \n", + "time to run. Note that for the Blume-Capel model, the bgm function does not \n", + "collapse the categories that have no observations between zero and the last \n", + "category. This may explain the large discrepancy between the first and last \n", + "category values." + )) } # Check to see if not all responses are in one category -------------------- - if(num_categories[node] == 0) - stop(paste0("Only one value [", - unq_vls, - "] was observed for variable ", - node, - ".")) + if(num_categories[node] == 0) { + stop(paste0( + "Only one value [", + unq_vls, + "] was observed for variable ", + node, + "." + )) + } } if(check_fail_zero == TRUE) { if(length(failed_zeroes) == 1) { node = failed_zeroes[1] - warning(paste0("The bgm function assumes that the observed ordinal variables are integers and \n", - "that the lowest observed category score is zero. The lowest score for node \n", - node, " was recoded to zero for the analysis.\n", - "Note that bgm also recoded the corresponding reference category score to ", baseline_category[node], ".")) + warning(paste0( + "The bgm function assumes that the observed ordinal variables are integers and \n", + "that the lowest observed category score is zero. The lowest score for node \n", + node, " was recoded to zero for the analysis.\n", + "Note that bgm also recoded the corresponding reference category score to ", baseline_category[node], "." + )) } else { - warning(paste0("The bgm function assumes that the observed ordinal variables are integers and \n", - "that the lowest observed category score is zero. The lowest score for nodes \n", - paste(failed_zeroes, collapse = ","), " were recoded to zero for the analysis.\n", - "Note that bgm also recoded the corresponding reference category scores.")) + warning(paste0( + "The bgm function assumes that the observed ordinal variables are integers and \n", + "that the lowest observed category score is zero. The lowest score for nodes \n", + paste(failed_zeroes, collapse = ","), " were recoded to zero for the analysis.\n", + "Note that bgm also recoded the corresponding reference category scores." + )) } } - return(list(x = x, - num_categories = num_categories, - baseline_category = baseline_category, - missing_index = missing_index, - na_impute = na_impute)) + return(list( + x = x, + num_categories = num_categories, + baseline_category = baseline_category, + missing_index = missing_index, + na_impute = na_impute + )) } # Helper function for data checks data_check = function(data, name) { - if (!inherits(data, c("matrix", "data.frame"))) { + if(!inherits(data, c("matrix", "data.frame"))) { stop(paste(name, "must be a matrix or data.frame.")) } - if (inherits(data, "data.frame")) { + if(inherits(data, "data.frame")) { data = data.matrix(data) } - if (nrow(data) < 2 || ncol(data) < 2) { + if(nrow(data) < 2 || ncol(data) < 2) { stop(paste(name, "must have at least 2 rows and 2 columns.")) } return(data) @@ -192,10 +236,10 @@ data_check = function(data, name) { # Helper function for computing `counts_per_category` compute_counts_per_category = function(x, num_categories, group = NULL) { counts_per_category = list() - for (g in unique(group)) { + for(g in unique(group)) { counts_per_category_gr = matrix(0, nrow = max(num_categories), ncol = ncol(x)) - for (variable in seq_len(ncol(x))) { - for (category in seq_len(num_categories[variable])) { + for(variable in seq_len(ncol(x))) { + for(category in seq_len(num_categories[variable])) { counts_per_category_gr[category, variable] = sum(x[group == g, variable] == category) } } @@ -206,22 +250,22 @@ compute_counts_per_category = function(x, num_categories, group = NULL) { # Helper function for computing sufficient statistics for Blume-Capel variables compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, group = NULL) { - if (is.null(group)) { # One-group design + if(is.null(group)) { # One-group design sufficient_stats = matrix(0, nrow = 2, ncol = ncol(x)) bc_vars = which(!ordinal_variable) - for (i in bc_vars) { + for(i in bc_vars) { sufficient_stats[1, i] = sum(x[, i]) - sufficient_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2) + sufficient_stats[2, i] = sum((x[, i] - baseline_category[i])^2) } return(sufficient_stats) - } else { # Multi-group design + } else { # Multi-group design sufficient_stats = list() - for (g in unique(group)) { + for(g in unique(group)) { sufficient_stats_gr = matrix(0, nrow = 2, ncol = ncol(x)) bc_vars = which(!ordinal_variable) - for (i in bc_vars) { + for(i in bc_vars) { sufficient_stats_gr[1, i] = sum(x[group == g, i]) - sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i]) ^ 2) + sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i])^2) } sufficient_stats[[g]] = sufficient_stats_gr } @@ -233,7 +277,7 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro compute_pairwise_stats <- function(x, group) { result <- vector("list", length(unique(group))) - for (g in unique(group)) { + for(g in unique(group)) { obs <- x[group == g, , drop = FALSE] # cross-product: gives number of co-occurrences of categories result[[g]] <- t(obs) %*% obs @@ -244,53 +288,80 @@ compute_pairwise_stats <- function(x, group) { compare_reformat_data = function( - x, - group, - na_action, - variable_bool, - baseline_category + x, + group, + na_action, + variable_bool, + baseline_category ) { if(na_action == "listwise") { # Check for missing values in x -------------------------------------------- - missing_values = sapply(1:nrow(x), function(row){anyNA(x[row, ])}) - if(sum(missing_values) == nrow(x)) - stop(paste0("All rows in x contain at least one missing response.\n", - "You could try option na_action = impute.")) - if(sum(missing_values) > 1) - warning(paste0("There were ", - sum(missing_values), - " rows with missing observations in the input matrix x.\n", - "Since na_action = listwise these rows were excluded from the analysis."), - call. = FALSE) - if(sum(missing_values) == 1) - warning(paste0("There was one row with missing observations in the input matrix x.\n", - "Since na_action = listwise this row was excluded from \n", - "the analysis."), - call. = FALSE) + missing_values = sapply(1:nrow(x), function(row) { + anyNA(x[row, ]) + }) + if(sum(missing_values) == nrow(x)) { + stop(paste0( + "All rows in x contain at least one missing response.\n", + "You could try option na_action = impute." + )) + } + if(sum(missing_values) > 1) { + warning( + paste0( + "There were ", + sum(missing_values), + " rows with missing observations in the input matrix x.\n", + "Since na_action = listwise these rows were excluded from the analysis." + ), + call. = FALSE + ) + } + if(sum(missing_values) == 1) { + warning( + paste0( + "There was one row with missing observations in the input matrix x.\n", + "Since na_action = listwise this row was excluded from \n", + "the analysis." + ), + call. = FALSE + ) + } x = x[!missing_values, ] group = group[!missing_values] - if(nrow(x) < 2 || is.null(nrow(x))) - stop(paste0("After removing missing observations from the input matrix x,\n", - "there were less than two rows left in x.")) + if(nrow(x) < 2 || is.null(nrow(x))) { + stop(paste0( + "After removing missing observations from the input matrix x,\n", + "there were less than two rows left in x." + )) + } unique_g = unique(group) - if(length(unique_g) == length(group)) - stop(paste0("After rows with missing observations were excluded, there were no groups, as \n", - "there were only unique values in the input g left.")) - if(length(unique_g) == 1) - stop(paste0("After rows with missing observations were excluded, there were no groups, as \n", - "there was only one value in the input g left.")) + if(length(unique_g) == length(group)) { + stop(paste0( + "After rows with missing observations were excluded, there were no groups, as \n", + "there were only unique values in the input g left." + )) + } + if(length(unique_g) == 1) { + stop(paste0( + "After rows with missing observations were excluded, there were no groups, as \n", + "there was only one value in the input g left." + )) + } g = group for(u in unique_g) { group[g == u] = which(unique_g == u) } tab = tabulate(group) - if(any(tab < 2)) - stop(paste0("After rows with missing observations were excluded, one or more groups, only \n", - "had one member in the input g.")) + if(any(tab < 2)) { + stop(paste0( + "After rows with missing observations were excluded, one or more groups, only \n", + "had one member in the input g." + )) + } missing_index = matrix(NA, nrow = 1, ncol = 1) na_impute = FALSE @@ -308,10 +379,10 @@ compare_reformat_data = function( if(length(mis) > 0) { for(i in 1:length(mis)) { cntr = cntr + 1 - missing_index[cntr, 1] = mis[i] - 1 #c++ index starts at 0 - missing_index[cntr, 2] = node - 1 #c++ index starts at 0 - x[mis[i], node] = sample(x[-mis, node], size = 1) #start value for imputation - #This is non-zero if no zeroes are observed (we then collapse over zero below) + missing_index[cntr, 1] = mis[i] - 1 # c++ index starts at 0 + missing_index[cntr, 2] = node - 1 # c++ index starts at 0 + x[mis[i], node] = sample(x[-mis, node], size = 1) # start value for imputation + # This is non-zero if no zeroes are observed (we then collapse over zero below) } } } @@ -323,24 +394,28 @@ compare_reformat_data = function( check_fail_zero = FALSE num_variables = ncol(x) - num_categories = vector (length = num_variables) + num_categories = vector(length = num_variables) for(node in 1:num_variables) { - unq_vls = sort(unique(x[, node])) + unq_vls = sort(unique(x[, node])) mx_vls = length(unq_vls) # Check if observed responses are not all unique --------------------------- - if(mx_vls == nrow(x)) - stop(paste0("Only unique responses observed for variable ", - node, - " in the matrix x (group 1). We expect >= 1 observations per category.")) + if(mx_vls == nrow(x)) { + stop(paste0( + "Only unique responses observed for variable ", + node, + " in the matrix x (group 1). We expect >= 1 observations per category." + )) + } # Recode data -------------------------------------------------------------- - if(variable_bool[node]) {#Regular ordinal variable + if(variable_bool[node]) { # Regular ordinal variable # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) observed_scores = matrix(NA, - nrow = mx_vls, - ncol = max(group)) + nrow = mx_vls, + ncol = max(group) + ) for(value in unq_vls) { unique_g = unique(group) @@ -353,42 +428,46 @@ compare_reformat_data = function( xx = x[, node] cntr = -1 for(value in unq_vls) { - #Collapse categories when not observed in one or more groups. + # Collapse categories when not observed in one or more groups. if(sum(observed_scores[which(unq_vls == value), ]) == max(group)) { - cntr = cntr + 1 #increment score if category observed in all groups + cntr = cntr + 1 # increment score if category observed in all groups } x[xx == value, node] = max(0, cntr) } - - } else {#Blume-Capel ordinal variable + } else { # Blume-Capel ordinal variable # Check if observations are integer or can be recoded -------------------- - if (any(abs(unq_vls - round(unq_vls)) > .Machine$double.eps)) { + if(any(abs(unq_vls - round(unq_vls)) > .Machine$double.eps)) { int_unq_vls = unique(as.integer(unq_vls)) if(anyNA(int_unq_vls)) { stop(paste0( "The Blume-Capel model assumes that its observations are coded as integers, but \n", "the category scores for node ", node, " were not integer. An attempt to recode \n", "them to integer failed. Please inspect the documentation for the base R function \n", - "as.integer(), which bgmCompare uses for recoding category scores.")) + "as.integer(), which bgmCompare uses for recoding category scores." + )) } if(length(int_unq_vls) != length(unq_vls)) { - stop(paste0("The Blume-Capel model assumes that its observations are coded as integers. The \n", - "category scores of the observations for node ", node, " were not integers. An \n", - "attempt to recode these observations as integers failed because, after rounding,\n", - "a single integer value was used for several observed score categories.")) + stop(paste0( + "The Blume-Capel model assumes that its observations are coded as integers. The \n", + "category scores of the observations for node ", node, " were not integers. An \n", + "attempt to recode these observations as integers failed because, after rounding,\n", + "a single integer value was used for several observed score categories." + )) } x[, node] = as.integer(x[, node]) } - mi = min(x[,node]) + mi = min(x[, node]) - ma = max(x[,node]) + ma = max(x[, node]) - if(baseline_category[node] < mi | baseline_category[node] > ma) + if(baseline_category[node] < mi | baseline_category[node] > ma) { stop(paste0( "The reference category for the Blume-Capel variable ", node, "is outside its \n", - "range of observations in the matrices x (and y).")) + "range of observations in the matrices x (and y)." + )) + } # Check if observations start at zero and recode otherwise --------------- if(mi != 0) { @@ -405,11 +484,14 @@ compare_reformat_data = function( check_range = length(unique(x[, node])) - if(check_range < 3) - stop(paste0("The Blume-Capel is only available for variables with more than two categories \n", - "observed. There are two or less categories observed for variable ", - node, - ".")) + if(check_range < 3) { + stop(paste0( + "The Blume-Capel is only available for variables with more than two categories \n", + "observed. There are two or less categories observed for variable ", + node, + "." + )) + } } @@ -419,43 +501,52 @@ compare_reformat_data = function( if(!variable_bool[node] & max(num_categories[node]) > 10) { # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) - warning(paste0("In the (pseudo) likelihood of Blume-Capel variables, the normalization constant \n", - "is a sum over all possible values of the ordinal variable. The range of \n", - "observed values, possibly after recoding to integers, is assumed to be the \n", - "number of possible response categories. For node ", node,", in group 1, this \n", - "range was equal to ", max(num_categories[node]), "which may cause the analysis to take some \n", - "time to run. Note that for the Blume-Capel model, the bgm function does not \n", - "collapse the categories that have no observations between zero and the last \n", - "category. This may explain the large discrepancy between the first and last \n", - "category values.")) + warning(paste0( + "In the (pseudo) likelihood of Blume-Capel variables, the normalization constant \n", + "is a sum over all possible values of the ordinal variable. The range of \n", + "observed values, possibly after recoding to integers, is assumed to be the \n", + "number of possible response categories. For node ", node, ", in group 1, this \n", + "range was equal to ", max(num_categories[node]), "which may cause the analysis to take some \n", + "time to run. Note that for the Blume-Capel model, the bgm function does not \n", + "collapse the categories that have no observations between zero and the last \n", + "category. This may explain the large discrepancy between the first and last \n", + "category values." + )) } # Check to see if not all responses are in one category -------------------- - if(any(num_categories[node] == 0)) + if(any(num_categories[node] == 0)) { stop(paste0("Only one value was observed for variable ", node, ".")) + } } if(check_fail_zero == TRUE) { if(length(failed_zeroes) == 1) { node = failed_zeroes[1] - warning(paste0("The bgm function assumes that the observed ordinal variables are integers and \n", - "that the lowest observed category score is zero. The lowest score for node \n", - node, " was recoded to zero for the analysis. Note that bgm also recoded the \n", - "the corresponding reference category score to ", baseline_category[node], ".")) + warning(paste0( + "The bgm function assumes that the observed ordinal variables are integers and \n", + "that the lowest observed category score is zero. The lowest score for node \n", + node, " was recoded to zero for the analysis. Note that bgm also recoded the \n", + "the corresponding reference category score to ", baseline_category[node], "." + )) } else { - warning(paste0("The bgm function assumes that the observed ordinal variables are integers and \n", - "that the lowest observed category score is zero. The lowest score for nodes \n", - paste(failed_zeroes, collapse = ","), " were recoded to zero for the analysis. \n", - "Note that bgm also recoded the corresponding reference category scores.")) + warning(paste0( + "The bgm function assumes that the observed ordinal variables are integers and \n", + "that the lowest observed category score is zero. The lowest score for nodes \n", + paste(failed_zeroes, collapse = ","), " were recoded to zero for the analysis. \n", + "Note that bgm also recoded the corresponding reference category scores." + )) } } - return(list(x = x, - group = group, - num_categories = num_categories, - baseline_category = baseline_category, - missing_index = missing_index, - na_impute = na_impute)) + return(list( + x = x, + group = group, + num_categories = num_categories, + baseline_category = baseline_category, + missing_index = missing_index, + na_impute = na_impute + )) } diff --git a/R/extractor_functions.R b/R/extractor_functions.R index ff64dd1..678a5f2 100644 --- a/R/extractor_functions.R +++ b/R/extractor_functions.R @@ -11,7 +11,7 @@ #' - `extract_category_thresholds()` – Posterior mean of category thresholds #' - `extract_indicator_priors()` – Prior structure used for edge indicators #' - `extract_sbm` – Extract stochastic block model parameters (if applicable) - #' +#' #' @name extractor_functions #' @title Extractor Functions for bgms Objects #' @keywords internal @@ -26,8 +26,8 @@ extract_arguments = function(bgms_object) { #' @rdname extractor_functions #' @export extract_arguments.bgms = function(bgms_object) { - if (!inherits(bgms_object, "bgms")) stop("Object must be of class 'bgms'.") - if (is.null(bgms_object$arguments)) { + if(!inherits(bgms_object, "bgms")) stop("Object must be of class 'bgms'.") + if(is.null(bgms_object$arguments)) { stop("Fit object predates bgms version 0.1.3. Upgrade the model output.") } return(bgms_object$arguments) @@ -37,8 +37,10 @@ extract_arguments.bgms = function(bgms_object) { #' @export extract_arguments.bgmCompare = function(bgms_object) { if(is.null(bgms_object$arguments)) { - stop(paste0("Extractor functions have been defined for bgms versions 0.1.3 and up but not \n", - "for older versions. The current fit object predates version 0.1.3.")) + stop(paste0( + "Extractor functions have been defined for bgms versions 0.1.3 and up but not \n", + "for older versions. The current fit object predates version 0.1.3." + )) } else { return(bgms_object$arguments) } @@ -59,20 +61,24 @@ extract_indicators = function(bgms_object) { extract_indicators.bgms = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$edge_selection)) { + if(!isTRUE(arguments$edge_selection)) { stop("To access edge indicators, the model must be run with edge_selection = TRUE.") } # Resolve indicator samples indicators_list = bgms_object$raw_samples$indicator - if (is.null(indicators_list)) { - if (!is.null(bgms_object$indicator)) { - lifecycle::deprecate_warn("0.1.6.0", "bgms_object$indicator", - "bgms_object$raw_samples$indicator") + if(is.null(indicators_list)) { + if(!is.null(bgms_object$indicator)) { + lifecycle::deprecate_warn( + "0.1.6.0", "bgms_object$indicator", + "bgms_object$raw_samples$indicator" + ) indicators_list = bgms_object$indicator - } else if (!is.null(bgms_object$gamma)) { - lifecycle::deprecate_warn("0.1.4.2", "bgms_object$gamma", - "bgms_object$raw_samples$indicator") + } else if(!is.null(bgms_object$gamma)) { + lifecycle::deprecate_warn( + "0.1.4.2", "bgms_object$gamma", + "bgms_object$raw_samples$indicator" + ) indicators_list = bgms_object$gamma } else { stop("No indicator samples found in this object.") @@ -84,7 +90,7 @@ extract_indicators.bgms = function(bgms_object) { # Assign column names if available param_names = bgms_object$raw_samples$parameter_names$indicator - if (!is.null(param_names)) { + if(!is.null(param_names)) { colnames(indicator_samples) = param_names } @@ -96,18 +102,18 @@ extract_indicators.bgms = function(bgms_object) { extract_indicators.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$difference_selection)) { + if(!isTRUE(arguments$difference_selection)) { stop("To access difference indicators, the model must be run with difference_selection = TRUE.") } indicators_list = bgms_object$raw_samples$indicator - if (is.null(indicators_list)) { + if(is.null(indicators_list)) { stop("No indicator samples found in this object.") } indicator_samples = do.call(rbind, indicators_list) param_names = bgms_object$raw_samples$parameter_names$indicators - if (!is.null(param_names)) { + if(!is.null(param_names)) { colnames(indicator_samples) = param_names } return(indicator_samples) @@ -130,7 +136,7 @@ extract_posterior_inclusion_probabilities = function(bgms_object) { extract_posterior_inclusion_probabilities.bgms = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$edge_selection)) { + if(!isTRUE(arguments$edge_selection)) { stop("To estimate posterior inclusion probabilities, run bgm() with edge_selection = TRUE.") } @@ -139,25 +145,27 @@ extract_posterior_inclusion_probabilities.bgms = function(bgms_object) { edge_means = NULL # New format: use extract_indicators() - if (!is.null(bgms_object$raw_samples$indicator)) { + if(!is.null(bgms_object$raw_samples$indicator)) { indicator_samples = extract_indicators(bgms_object) edge_means = colMeans(indicator_samples) - } else if (!is.null(bgms_object$indicator)) { - lifecycle::deprecate_warn("0.1.6.0", - "bgms_object$indicator", - "bgms_object$raw_samples$indicator" + } else if(!is.null(bgms_object$indicator)) { + lifecycle::deprecate_warn( + "0.1.6.0", + "bgms_object$indicator", + "bgms_object$raw_samples$indicator" ) - if (!is.null(arguments$save) && isTRUE(arguments$save)) { + if(!is.null(arguments$save) && isTRUE(arguments$save)) { edge_means = colMeans(bgms_object$indicator) } else { edge_means = bgms_object$indicator } - } else if (!is.null(bgms_object$gamma)) { - lifecycle::deprecate_warn("0.1.4.2", - "bgms_object$gamma", - "bgms_object$raw_samples$indicator" + } else if(!is.null(bgms_object$gamma)) { + lifecycle::deprecate_warn( + "0.1.4.2", + "bgms_object$gamma", + "bgms_object$raw_samples$indicator" ) - if (!is.null(arguments$save) && isTRUE(arguments$save)) { + if(!is.null(arguments$save) && isTRUE(arguments$save)) { edge_means = colMeans(bgms_object$gamma) } else { edge_means = bgms_object$gamma @@ -186,45 +194,51 @@ extract_sbm = function(bgms_object) { #' @rdname extractor_functions #' @export extract_sbm.bgms = function(bgms_object) { - if (!inherits(bgms_object, "bgms")) stop("Object must be of class 'bgms'.") + if(!inherits(bgms_object, "bgms")) stop("Object must be of class 'bgms'.") # Checks ver = try(utils::packageVersion("bgms"), silent = TRUE) - if (inherits(ver, "try-error") || is.na(ver)) { + if(inherits(ver, "try-error") || is.na(ver)) { stop("Could not determine 'bgms' package version.") } - if (utils::compareVersion(as.character(ver), "0.1.6.0") < 0) { - stop(paste0("Extractor functions for the SBM prior are defined for bgms version 0.1.6.0. ", - "The current installed version is ", as.character(ver), ".")) + if(utils::compareVersion(as.character(ver), "0.1.6.0") < 0) { + stop(paste0( + "Extractor functions for the SBM prior are defined for bgms version 0.1.6.0. ", + "The current installed version is ", as.character(ver), "." + )) } arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$edge_selection)) { + if(!isTRUE(arguments$edge_selection)) { stop("To extract SBM summaries, run bgm() with edge_selection = TRUE.") } - if (!identical(arguments$edge_prior, "Stochastic-Block")) { - stop(paste0("edge_prior must be 'Stochastic-Block' (got '", - as.character(arguments$edge_prior), "').")) + if(!identical(arguments$edge_prior, "Stochastic-Block")) { + stop(paste0( + "edge_prior must be 'Stochastic-Block' (got '", + as.character(arguments$edge_prior), "')." + )) } - posterior_num_blocks = try(bgms_object$posterior_num_blocks, silent = TRUE) - posterior_mean_allocations = try(bgms_object$posterior_mean_allocations, silent = TRUE) - posterior_mode_allocations = try(bgms_object$posterior_mode_allocations, silent = TRUE) + posterior_num_blocks = try(bgms_object$posterior_num_blocks, silent = TRUE) + posterior_mean_allocations = try(bgms_object$posterior_mean_allocations, silent = TRUE) + posterior_mode_allocations = try(bgms_object$posterior_mode_allocations, silent = TRUE) posterior_mean_coclustering_matrix = try(bgms_object$posterior_mean_coclustering_matrix, silent = TRUE) - if (inherits(posterior_num_blocks, "try-error")) posterior_num_blocks = NULL - if (inherits(posterior_mean_allocations, "try-error")) posterior_mean_allocations = NULL - if (inherits(posterior_mode_allocations, "try-error")) posterior_mode_allocations = NULL - if (inherits(posterior_mean_coclustering_matrix, "try-error")) posterior_mean_coclustering_matrix = NULL - - if (is.null(posterior_num_blocks) || - is.null(posterior_mean_allocations) || - is.null(posterior_mode_allocations) || - is.null(posterior_mean_coclustering_matrix)) { - stop(paste0("SBM summaries not found in this object. Missing one or more of: ", - "posterior_num_blocks, posterior_mean_allocations, ", - "posterior_mode_allocations, posterior_mean_coclustering_matrix.")) + if(inherits(posterior_num_blocks, "try-error")) posterior_num_blocks = NULL + if(inherits(posterior_mean_allocations, "try-error")) posterior_mean_allocations = NULL + if(inherits(posterior_mode_allocations, "try-error")) posterior_mode_allocations = NULL + if(inherits(posterior_mean_coclustering_matrix, "try-error")) posterior_mean_coclustering_matrix = NULL + + if(is.null(posterior_num_blocks) || + is.null(posterior_mean_allocations) || + is.null(posterior_mode_allocations) || + is.null(posterior_mean_coclustering_matrix)) { + stop(paste0( + "SBM summaries not found in this object. Missing one or more of: ", + "posterior_num_blocks, posterior_mean_allocations, ", + "posterior_mode_allocations, posterior_mean_coclustering_matrix." + )) } @@ -242,35 +256,37 @@ extract_sbm.bgms = function(bgms_object) { extract_posterior_inclusion_probabilities.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$difference_selection)) { + if(!isTRUE(arguments$difference_selection)) { stop("To estimate posterior inclusion probabilities, run bgmCompare() with difference_selection = TRUE.") } - var_names = arguments$data_columnnames + var_names = arguments$data_columnnames num_categories = as.integer(arguments$num_categories) - is_ordinal = as.logical(arguments$is_ordinal_variable) - num_groups = as.integer(arguments$num_groups) - num_variables = as.integer(arguments$num_variables) - projection = arguments$projection # [num_groups x (num_groups-1)] + is_ordinal = as.logical(arguments$is_ordinal_variable) + num_groups = as.integer(arguments$num_groups) + num_variables = as.integer(arguments$num_variables) + projection = arguments$projection # [num_groups x (num_groups-1)] # ---- helper: combine chains into [iter, chain, param], robust to vectors/1-col to_array3d = function(xlist) { - if (is.null(xlist)) return(NULL) + if(is.null(xlist)) { + return(NULL) + } stopifnot(length(xlist) >= 1) mats = lapply(xlist, function(x) { m = as.matrix(x) - if (is.null(dim(m))) m = matrix(m, ncol = 1L) + if(is.null(dim(m))) m = matrix(m, ncol = 1L) m }) - niter = nrow(mats[[1]]) + niter = nrow(mats[[1]]) nparam = ncol(mats[[1]]) arr = array(NA_real_, dim = c(niter, length(mats), nparam)) - for (c in seq_along(mats)) arr[, c, ] = mats[[c]] + for(c in seq_along(mats)) arr[, c, ] = mats[[c]] arr } array3d_ind = to_array3d(bgms_object$raw_samples$indicator) - if (!is.null(array3d_ind)) { + if(!is.null(array3d_ind)) { mean_ind = apply(array3d_ind, 3, mean) # reconstruct VxV matrix using the sampler’s interleaved order: @@ -278,15 +294,19 @@ extract_posterior_inclusion_probabilities.bgmCompare = function(bgms_object) { V = num_variables stopifnot(length(mean_ind) == V * (V + 1L) / 2L) - ind_mat = matrix(0, nrow = V, ncol = V, - dimnames = list(var_names, var_names)) + ind_mat = matrix(0, + nrow = V, ncol = V, + dimnames = list(var_names, var_names) + ) pos = 1L - for (i in seq_len(V)) { + for(i in seq_len(V)) { # diagonal (main indicator) - ind_mat[i, i] = mean_ind[pos]; pos = pos + 1L - if (i < V) { - for (j in (i + 1L):V) { - val = mean_ind[pos]; pos = pos + 1L + ind_mat[i, i] = mean_ind[pos] + pos = pos + 1L + if(i < V) { + for(j in (i + 1L):V) { + val = mean_ind[pos] + pos = pos + 1L ind_mat[i, j] = val ind_mat[j, i] = val } @@ -312,10 +332,9 @@ extract_indicator_priors = function(bgms_object) { #' @export extract_indicator_priors.bgms = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$edge_selection)) stop("No edge selection performed.") + if(!isTRUE(arguments$edge_selection)) stop("No edge selection performed.") - switch( - arguments$edge_prior, + switch(arguments$edge_prior, "Bernoulli" = list(type = "Bernoulli", prior_inclusion_probability = arguments$inclusion_probability), "Beta-Bernoulli" = list(type = "Beta-Bernoulli", alpha = arguments$beta_bernoulli_alpha, beta = arguments$beta_bernoulli_beta), "Stochastic-Block" = list( @@ -333,7 +352,7 @@ extract_indicator_priors.bgms = function(bgms_object) { extract_indicator_priors.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) - if (!isTRUE(arguments$difference_selection)) { + if(!isTRUE(arguments$difference_selection)) { stop("The model ran without selection, so there are no indicator priors specified.") } @@ -341,7 +360,6 @@ extract_indicator_priors.bgmCompare = function(bgms_object) { } - #' @rdname extractor_functions #' @export extract_pairwise_interactions = function(bgms_object) { @@ -359,26 +377,26 @@ extract_pairwise_interactions.bgms = function(bgms_object) { nchains = length(bgms_object$raw_samples$pairwise) mat = NULL mats = bgms_object$raw_samples$pairwise - mat = do.call(rbind, mats) + mat = do.call(rbind, mats) edge_names = character() - for (i in 1:(num_vars - 1)) { - for (j in (i + 1):num_vars) { + for(i in 1:(num_vars - 1)) { + for(j in (i + 1):num_vars) { edge_names = c(edge_names, paste0(var_names[i], "-", var_names[j])) } } dimnames(mat) = list(paste0("iter", 1:nrow(mat)), edge_names) - } else if (!is.null(bgms_object$posterior_summary_pairwise)) { + } else if(!is.null(bgms_object$posterior_summary_pairwise)) { vec = bgms_object$posterior_summary_pairwise[, "mean"] mat = matrix(0, nrow = num_vars, ncol = num_vars) mat[upper.tri(mat)] = vec mat[lower.tri(mat)] = t(mat)[lower.tri(mat)] dimnames(mat) = list(var_names, var_names) - } else if (!is.null(bgms_object$posterior_mean_pairwise)) { + } else if(!is.null(bgms_object$posterior_mean_pairwise)) { mat = bgms_object$posterior_mean_pairwise dimnames(mat) = list(var_names, var_names) - } else if (!is.null(bgms_object$pairwise_effects)) { + } else if(!is.null(bgms_object$pairwise_effects)) { mat = bgms_object$pairwise_effects dimnames(mat) = list(var_names, var_names) } else { @@ -395,7 +413,7 @@ extract_pairwise_interactions.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) if(is.null(bgms_object$raw_samples$pairwise)) { - stop('No raw samples found for the pairwise effects in the object.') + stop("No raw samples found for the pairwise effects in the object.") } pairwise_list = bgms_object$raw_samples$pairwise @@ -426,7 +444,7 @@ extract_category_thresholds.bgms = function(bgms_object) { arguments = extract_arguments(bgms_object) var_names = arguments$data_columnnames - if (!is.null(bgms_object$posterior_summary_main)) { + if(!is.null(bgms_object$posterior_summary_main)) { vec = bgms_object$posterior_summary_main[, "mean"] num_vars = arguments$num_variables variable_type = arguments$variable_type @@ -438,8 +456,8 @@ extract_category_thresholds.bgms = function(bgms_object) { mat = matrix(NA_real_, nrow = num_vars, ncol = max_cats) rownames(mat) = var_names pos = 1 - for (v in seq_len(num_vars)) { - if (variable_type[v] == "ordinal") { + for(v in seq_len(num_vars)) { + if(variable_type[v] == "ordinal") { k = num_cats[v] mat[v, 1:k] = vec[pos:(pos + k - 1)] pos = pos + k @@ -449,18 +467,20 @@ extract_category_thresholds.bgms = function(bgms_object) { } } return(mat) - } else if (!is.null(bgms_object$posterior_mean_main)) { + } else if(!is.null(bgms_object$posterior_mean_main)) { # Deprecated intermediate format - lifecycle::deprecate_warn("0.1.6.0", - "bgms_object$posterior_mean_main", - "bgms_object$posterior_summary_main" + lifecycle::deprecate_warn( + "0.1.6.0", + "bgms_object$posterior_mean_main", + "bgms_object$posterior_summary_main" ) mat = bgms_object$posterior_mean_main - } else if (!is.null(bgms_object$main_effects)) { + } else if(!is.null(bgms_object$main_effects)) { # Deprecated old format - lifecycle::deprecate_warn("0.1.4.2", - "bgms_object$main_effects", - "bgms_object$posterior_summary_main" + lifecycle::deprecate_warn( + "0.1.4.2", + "bgms_object$main_effects", + "bgms_object$posterior_summary_main" ) mat = bgms_object$main_effects } else { @@ -477,7 +497,7 @@ extract_category_thresholds.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) if(is.null(bgms_object$raw_samples$main)) { - stop('No raw samples found for the main effects in the object.') + stop("No raw samples found for the main effects in the object.") } main_list = bgms_object$raw_samples$main @@ -508,21 +528,23 @@ extract_group_params.bgmCompare = function(bgms_object) { is_ordinal = as.logical(arguments$is_ordinal_variable) num_groups = as.integer(arguments$num_groups) num_variables = as.integer(arguments$num_variables) - projection = arguments$projection # [num_groups x (num_groups-1)] + projection = arguments$projection # [num_groups x (num_groups-1)] # ---- helper: combine chains into [iter, chain, param], robust to vectors/1-col to_array3d = function(xlist) { - if (is.null(xlist)) return(NULL) + if(is.null(xlist)) { + return(NULL) + } stopifnot(length(xlist) >= 1) mats = lapply(xlist, function(x) { m = as.matrix(x) - if (is.null(dim(m))) m = matrix(m, ncol = 1L) + if(is.null(dim(m))) m = matrix(m, ncol = 1L) m }) - niter = nrow(mats[[1]]) + niter = nrow(mats[[1]]) nparam = ncol(mats[[1]]) arr = array(NA_real_, dim = c(niter, length(mats), nparam)) - for (c in seq_along(mats)) arr[, c, ] = mats[[c]] + for(c in seq_along(mats)) arr[, c, ] = mats[[c]] arr } @@ -539,20 +561,22 @@ extract_group_params.bgmCompare = function(bgms_object) { # row names in sampler row order rownames(main_mat) = unlist(lapply(seq_len(num_variables), function(v) { - if (is_ordinal[v]) { + if(is_ordinal[v]) { paste0(var_names[v], "(c", seq_len(num_categories[v]), ")") } else { - c(paste0(var_names[v], "(linear)"), - paste0(var_names[v], "(quadratic)")) + c( + paste0(var_names[v], "(linear)"), + paste0(var_names[v], "(quadratic)") + ) } })) colnames(main_mat) = c("baseline", paste0("diff", seq_len(num_groups - 1L))) # group-specific main effects: baseline + P %*% diffs main_effects_groups = matrix(NA_real_, nrow = num_main, ncol = num_groups) - for (r in seq_len(num_main)) { + for(r in seq_len(num_main)) { baseline = main_mat[r, 1] - diffs = main_mat[r, -1, drop = TRUE] + diffs = main_mat[r, -1, drop = TRUE] main_effects_groups[r, ] = baseline + as.vector(projection %*% diffs) } rownames(main_effects_groups) = rownames(main_mat) @@ -571,9 +595,9 @@ extract_group_params.bgmCompare = function(bgms_object) { # row names in sampler row order (upper-tri i= 2L) { - for (i in 1L:(num_variables - 1L)) { - for (j in (i + 1L):num_variables) { + if(num_variables >= 2L) { + for(i in 1L:(num_variables - 1L)) { + for(j in (i + 1L):num_variables) { pair_names = c(pair_names, paste0(var_names[i], "-", var_names[j])) } } @@ -583,9 +607,9 @@ extract_group_params.bgmCompare = function(bgms_object) { # group-specific pairwise effects pairwise_effects_groups = matrix(NA_real_, nrow = num_pair, ncol = num_groups) - for (r in seq_len(num_pair)) { + for(r in seq_len(num_pair)) { baseline = pairwise_mat[r, 1] - diffs = pairwise_mat[r, -1, drop = TRUE] + diffs = pairwise_mat[r, -1, drop = TRUE] pairwise_effects_groups[r, ] = baseline + as.vector(projection %*% diffs) } rownames(pairwise_effects_groups) = rownames(pairwise_mat) @@ -609,4 +633,4 @@ extract_edge_indicators = function(bgms_object) { extract_pairwise_thresholds = function(bgms_object) { lifecycle::deprecate_warn("0.1.4.2", "extract_pairwise_thresholds()", "extract_category_thresholds()") extract_category_thresholds(bgms_object) -} \ No newline at end of file +} diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 4377b10..4dde1a1 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -1,12 +1,12 @@ check_positive_integer = function(value, name) { - if (!is.numeric(value) || abs(value - round(value)) > .Machine$double.eps || value <= 0) { + if(!is.numeric(value) || abs(value - round(value)) > .Machine$double.eps || value <= 0) { stop(sprintf("Parameter `%s` must be a positive integer. Got: %s", name, value)) } } # Helper function for validating non-negative integers check_non_negative_integer = function(value, name) { - if (!is.numeric(value) || abs(value - round(value)) > .Machine$double.eps || value < 0) { + if(!is.numeric(value) || abs(value - round(value)) > .Machine$double.eps || value < 0) { stop(sprintf("Parameter `%s` must be a non-negative integer. Got: %s", name, value)) } } @@ -14,7 +14,7 @@ check_non_negative_integer = function(value, name) { # Helper function for validating logical inputs check_logical = function(value, name) { value = as.logical(value) - if (is.na(value)) { + if(is.na(value)) { stop(sprintf("Parameter `%s` must be TRUE or FALSE. Got: %s", name, value)) } return(value) @@ -35,81 +35,114 @@ check_model = function(x, beta_bernoulli_beta_between = 1, dirichlet_alpha = dirichlet_alpha, lambda = lambda) { - - #Check variable type input --------------------------------------------------- + # Check variable type input --------------------------------------------------- if(length(variable_type) == 1) { variable_input = variable_type - variable_type = try(match.arg(arg = variable_type, - choices = c("ordinal", "blume-capel")), - silent = TRUE) - if(inherits(variable_type, what = "try-error")) - stop(paste0("The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - variable_input, ".")) + variable_type = try( + match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgm function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + variable_input, "." + )) + } variable_bool = (variable_type == "ordinal") variable_bool = rep(variable_bool, ncol(x)) } else { - if(length(variable_type) != ncol(x)) - stop(paste0("The variable type vector variable_type should be either a single character\n", - "string or a vector of character strings of length p.")) + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The variable type vector variable_type should be either a single character\n", + "string or a vector of character strings of length p." + )) + } variable_input = unique(variable_type) - variable_type = try(match.arg(arg = variable_type, - choices = c("ordinal", "blume-capel"), - several.ok = TRUE), silent = TRUE) - - if(inherits(variable_type, what = "try-error")) - stop(paste0("The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input, collapse = ", "), ".")) + variable_type = try(match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel"), + several.ok = TRUE + ), silent = TRUE) + + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgm function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + paste0(variable_input, collapse = ", "), "." + )) + } num_types = sapply(variable_input, function(type) { - tmp = try(match.arg(arg = type, - choices = c("ordinal", "blume-capel")), - silent = TRUE) + tmp = try( + match.arg( + arg = type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) inherits(tmp, what = "try-error") }) - if(length(variable_type) != ncol(x)) - stop(paste0("The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input[num_types], collapse = ", "), ".")) + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The bgm function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + paste0(variable_input[num_types], collapse = ", "), "." + )) + } variable_bool = (variable_type == "ordinal") } - #Check Blume-Capel variable input -------------------------------------------- + # Check Blume-Capel variable input -------------------------------------------- if(any(!variable_bool)) { # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) - if(!hasArg("baseline_category")) + if(!hasArg("baseline_category")) { stop("The argument baseline_category is required for Blume-Capel variables.") + } - if(length(baseline_category) != ncol(x) && length(baseline_category) != 1) - stop(paste0("The argument baseline_category for the Blume-Capel model needs to be a \n", - "single integer or a vector of integers of length p.")) + if(length(baseline_category) != ncol(x) && length(baseline_category) != 1) { + stop(paste0( + "The argument baseline_category for the Blume-Capel model needs to be a \n", + "single integer or a vector of integers of length p." + )) + } if(length(baseline_category) == 1) { - #Check if the input is integer ------------------------------------------- + # Check if the input is integer ------------------------------------------- integer_check = try(as.integer(baseline_category), silent = TRUE) - if(is.na(integer_check)) - stop(paste0("The baseline_category argument for the Blume-Capel model contains either \n", - "a missing value or a value that could not be forced into an integer value.")) + if(is.na(integer_check)) { + stop(paste0( + "The baseline_category argument for the Blume-Capel model contains either \n", + "a missing value or a value that could not be forced into an integer value." + )) + } integer_check = baseline_category - round(baseline_category) - if(integer_check > .Machine$double.eps) + if(integer_check > .Machine$double.eps) { stop("Reference category needs to an integer value or a vector of integers of length p.") + } baseline_category = rep.int(baseline_category, times = ncol(x)) } - #Check if the input is integer ------------------------------------------- + # Check if the input is integer ------------------------------------------- blume_capel_variables = which(!variable_bool) # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) integer_check = try(as.integer(baseline_category[blume_capel_variables]), - silent = TRUE) - if(anyNA(integer_check)) - stop(paste0("The baseline_category argument for the Blume-Capel model contains either \n", - "missing values or values that could not be forced into an integer value.")) + silent = TRUE + ) + if(anyNA(integer_check)) { + stop(paste0( + "The baseline_category argument for the Blume-Capel model contains either \n", + "missing values or values that could not be forced into an integer value." + )) + } integer_check = baseline_category[blume_capel_variables] - round(baseline_category[blume_capel_variables]) @@ -117,11 +150,15 @@ check_model = function(x, if(any(integer_check > .Machine$double.eps)) { non_integers = blume_capel_variables[integer_check > .Machine$double.eps] if(length(non_integers) > 1) { - stop(paste0("The entries in baseline_category for variables ", - paste0(non_integers, collapse = ", "), " need to be integer.")) + stop(paste0( + "The entries in baseline_category for variables ", + paste0(non_integers, collapse = ", "), " need to be integer." + )) } else { - stop(paste0("The entry in baseline_category for variable ", - non_integers, " needs to be an integer.")) + stop(paste0( + "The entry in baseline_category for variable ", + non_integers, " needs to be an integer." + )) } } @@ -130,158 +167,193 @@ check_model = function(x, if(any(baseline_category < variable_lower) | any(baseline_category > variable_upper)) { out_of_range = which(baseline_category < variable_lower | baseline_category > variable_upper) - stop(paste0("The Blume-Capel model assumes that the reference category is within the range \n", - "of the observed category scores. This was not the case for variable(s) \n", - paste0(out_of_range, collapse =", "), - ".")) + stop(paste0( + "The Blume-Capel model assumes that the reference category is within the range \n", + "of the observed category scores. This was not the case for variable(s) \n", + paste0(out_of_range, collapse = ", "), + "." + )) } - } else { baseline_category = rep.int(0, times = ncol(x)) } - #Check prior set-up for the interaction parameters --------------------------- - if(pairwise_scale <= 0 || is.na(pairwise_scale) || is.infinite(pairwise_scale)) + # Check prior set-up for the interaction parameters --------------------------- + if(pairwise_scale <= 0 || is.na(pairwise_scale) || is.infinite(pairwise_scale)) { stop("The scale of the Cauchy prior needs to be positive.") + } - #Check prior set-up for the threshold parameters ----------------------------- - if(main_alpha <= 0 | !is.finite(main_alpha)) + # Check prior set-up for the threshold parameters ----------------------------- + if(main_alpha <= 0 | !is.finite(main_alpha)) { stop("Parameter main_alpha needs to be positive.") - if(main_beta <= 0 | !is.finite(main_beta)) + } + if(main_beta <= 0 | !is.finite(main_beta)) { stop("Parameter main_beta needs to be positive.") + } - #Check set-up for the Bayesian edge selection model -------------------------- + # Check set-up for the Bayesian edge selection model -------------------------- edge_selection = as.logical(edge_selection) - if(is.na(edge_selection)) + if(is.na(edge_selection)) { stop("The parameter edge_selection needs to be TRUE or FALSE.") + } if(edge_selection == TRUE) { - #Check prior set-up for the edge indicators -------------------------------- + # Check prior set-up for the edge indicators -------------------------------- edge_prior = match.arg(edge_prior) if(edge_prior == "Bernoulli") { if(length(inclusion_probability) == 1) { theta = inclusion_probability[1] - if(is.na(theta) || is.null(theta)) + if(is.na(theta) || is.null(theta)) { stop("There is no value specified for the inclusion probability.") - if(theta <= 0) + } + if(theta <= 0) { stop("The inclusion probability needs to be positive.") - if(theta > 1) + } + if(theta > 1) { stop("The inclusion probability cannot exceed the value one.") - if(theta == 1) + } + if(theta == 1) { stop("The inclusion probability cannot equal one.") + } theta = matrix(theta, nrow = ncol(x), ncol = ncol(x)) } else { if(!inherits(inclusion_probability, what = "matrix") && - !inherits(inclusion_probability, what = "data.frame")) + !inherits(inclusion_probability, what = "data.frame")) { stop("The input for the inclusion probability argument needs to be a single number, matrix, or dataframe.") + } if(inherits(inclusion_probability, what = "data.frame")) { theta = data.matrix(inclusion_probability) } else { theta = inclusion_probability } - if(!isSymmetric(theta)) + if(!isSymmetric(theta)) { stop("The inclusion probability matrix needs to be symmetric.") - if(ncol(theta) != ncol(x)) + } + if(ncol(theta) != ncol(x)) { stop("The inclusion probability matrix needs to have as many rows (columns) as there are variables in the data.") + } if(anyNA(theta[lower.tri(theta)]) || - any(is.null(theta[lower.tri(theta)]))) + any(is.null(theta[lower.tri(theta)]))) { stop("One or more elements of the elements in inclusion probability matrix are not specified.") - if(any(theta[lower.tri(theta)] <= 0)) - stop(paste0("The inclusion probability matrix contains negative or zero values;\n", - "inclusion probabilities need to be positive.")) - if(any(theta[lower.tri(theta)] >= 1)) - stop(paste0("The inclusion probability matrix contains values greater than or equal to one;\n", - "inclusion probabilities cannot exceed or equal the value one.")) + } + if(any(theta[lower.tri(theta)] <= 0)) { + stop(paste0( + "The inclusion probability matrix contains negative or zero values;\n", + "inclusion probabilities need to be positive." + )) + } + if(any(theta[lower.tri(theta)] >= 1)) { + stop(paste0( + "The inclusion probability matrix contains values greater than or equal to one;\n", + "inclusion probabilities cannot exceed or equal the value one." + )) + } } } if(edge_prior == "Beta-Bernoulli") { theta = matrix(0.5, nrow = ncol(x), ncol = ncol(x)) - if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0) + if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0) { stop("The scale parameters of the beta distribution need to be positive.") - if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta)) + } + if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta)) { stop("The scale parameters of the beta distribution need to be finite.") + } if(is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || - is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta)) + is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta)) { stop("Values for both scale parameters of the beta distribution need to be specified.") + } } if(edge_prior == "Stochastic-Block") { theta = matrix(0.5, nrow = ncol(x), ncol = ncol(x)) # Check that all beta parameters are provided - if (is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) || - is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { - stop("The Stochastic-Block prior requires all four beta parameters: ", - "beta_bernoulli_alpha, beta_bernoulli_beta, ", - "beta_bernoulli_alpha_between, and beta_bernoulli_beta_between.") + if(is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) || + is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) { + stop( + "The Stochastic-Block prior requires all four beta parameters: ", + "beta_bernoulli_alpha, beta_bernoulli_beta, ", + "beta_bernoulli_alpha_between, and beta_bernoulli_beta_between." + ) } # Check that all beta parameters are positive - if (beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 || - beta_bernoulli_alpha_between <= 0 || beta_bernoulli_beta_between <= 0 || - dirichlet_alpha <= 0 || lambda <= 0) { + if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 || + beta_bernoulli_alpha_between <= 0 || beta_bernoulli_beta_between <= 0 || + dirichlet_alpha <= 0 || lambda <= 0) { stop("The parameters of the beta and Dirichlet distributions need to be positive.") } # Check that all beta parameters are finite - if (!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) || - !is.finite(beta_bernoulli_alpha_between) || !is.finite(beta_bernoulli_beta_between) || - !is.finite(dirichlet_alpha) || !is.finite(lambda)) { - stop("The shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", - "and the rate parameter of the Poisson distribution need to be finite.") + if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) || + !is.finite(beta_bernoulli_alpha_between) || !is.finite(beta_bernoulli_beta_between) || + !is.finite(dirichlet_alpha) || !is.finite(lambda)) { + stop( + "The shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", + "and the rate parameter of the Poisson distribution need to be finite." + ) } # Check for NAs - if (is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || - is.na(beta_bernoulli_alpha_between) || is.na(beta_bernoulli_beta_between) || - is.na(dirichlet_alpha) || is.na(lambda)) { - stop("Values for all shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", - "and the rate parameter of the Poisson distribution cannot be NA.") + if(is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || + is.na(beta_bernoulli_alpha_between) || is.na(beta_bernoulli_beta_between) || + is.na(dirichlet_alpha) || is.na(lambda)) { + stop( + "Values for all shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ", + "and the rate parameter of the Poisson distribution cannot be NA." + ) } } - }else { + } else { theta = matrix(0.5, nrow = 1, ncol = 1) edge_prior = "Not Applicable" } - return(list(variable_bool = variable_bool, - baseline_category = baseline_category, - edge_selection = edge_selection, - edge_prior = edge_prior, - inclusion_probability = theta)) + return(list( + variable_bool = variable_bool, + baseline_category = baseline_category, + edge_selection = edge_selection, + edge_prior = edge_prior, + inclusion_probability = theta + )) } - check_compare_model = function( - x, - y, - group_indicator, - difference_selection, - variable_type, - baseline_category, - difference_scale = 2.5, - difference_prior = c("Bernoulli", "Beta-Bernoulli"), - difference_probability = 0.5, - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - pairwise_scale = 2.5, - main_alpha = 0.5, - main_beta = 0.5 + x, + y, + group_indicator, + difference_selection, + variable_type, + baseline_category, + difference_scale = 2.5, + difference_prior = c("Bernoulli", "Beta-Bernoulli"), + difference_probability = 0.5, + beta_bernoulli_alpha = 1, + beta_bernoulli_beta = 1, + pairwise_scale = 2.5, + main_alpha = 0.5, + main_beta = 0.5 ) { - if(!is.null(group_indicator)) { unique_g = unique(group_indicator) - if(length(unique_g) == 0) - stop(paste0("The bgmCompare function expects at least two groups, but the input group_indicator contains\n", - "no group value.")) - if(length(unique_g) == 1) - stop(paste0("The bgmCompare function expects at least two groups, but the input group_indicator contains\n", - "only one group value.")) - if(length(unique_g) == length(group_indicator)) + if(length(unique_g) == 0) { + stop(paste0( + "The bgmCompare function expects at least two groups, but the input group_indicator contains\n", + "no group value." + )) + } + if(length(unique_g) == 1) { + stop(paste0( + "The bgmCompare function expects at least two groups, but the input group_indicator contains\n", + "only one group value." + )) + } + if(length(unique_g) == length(group_indicator)) { stop("The input group_indicator contains only unique group values.") + } group = group_indicator for(u in unique_g) { @@ -289,87 +361,122 @@ check_compare_model = function( } tab = tabulate(group) - if(any(tab < 2)) + if(any(tab < 2)) { stop("One or more groups only had one member in the input group_indicator.") + } } else { group = c(rep.int(1, times = nrow(x)), rep.int(2, times = nrow(y))) x = rbind(x, y) } - #Check variable type input --------------------------------------------------- + # Check variable type input --------------------------------------------------- if(length(variable_type) == 1) { variable_input = variable_type - variable_type = try(match.arg(arg = variable_type, - choices = c("ordinal", "blume-capel")), - silent = TRUE) - if(inherits(variable_type, what = "try-error")) - stop(paste0("The bgmCompare function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - variable_input, ".")) + variable_type = try( + match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgmCompare function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + variable_input, "." + )) + } variable_bool = (variable_type == "ordinal") variable_bool = rep(variable_bool, ncol(x)) } else { - if(length(variable_type) != ncol(x)) - stop(paste0("The variable type vector variable_type should be either a single character\n", - "string or a vector of character strings of length p.")) + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The variable type vector variable_type should be either a single character\n", + "string or a vector of character strings of length p." + )) + } variable_input = unique(variable_type) - variable_type = try(match.arg(arg = variable_type, - choices = c("ordinal", "blume-capel"), - several.ok = TRUE), silent = TRUE) - - if(inherits(variable_type, what = "try-error")) - stop(paste0("The bgmCompare function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input, collapse = ", "), ".")) + variable_type = try(match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel"), + several.ok = TRUE + ), silent = TRUE) + + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgmCompare function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + paste0(variable_input, collapse = ", "), "." + )) + } num_types = sapply(variable_input, function(type) { - tmp = try(match.arg(arg = type, - choices = c("ordinal", "blume-capel")), - silent = TRUE) + tmp = try( + match.arg( + arg = type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) inherits(tmp, what = "try-error") }) - if(length(variable_type) != ncol(x)) - stop(paste0("The bgmCompare function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input[num_types], collapse = ", "), ".")) + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The bgmCompare function supports variables of type ordinal and blume-capel, \n", + "but not of type ", + paste0(variable_input[num_types], collapse = ", "), "." + )) + } variable_bool = (variable_type == "ordinal") } - #Check Blume-Capel variable input -------------------------------------------- + # Check Blume-Capel variable input -------------------------------------------- if(any(!variable_bool)) { # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) - if(!hasArg("baseline_category")) + if(!hasArg("baseline_category")) { stop("The argument baseline_category is required for Blume-Capel variables.") + } - if(length(baseline_category) != ncol(x) && length(baseline_category) != 1) - stop(paste0("The argument baseline_category for the Blume-Capel model needs to be a \n", - "single integer or a vector of integers of length p.")) + if(length(baseline_category) != ncol(x) && length(baseline_category) != 1) { + stop(paste0( + "The argument baseline_category for the Blume-Capel model needs to be a \n", + "single integer or a vector of integers of length p." + )) + } if(length(baseline_category) == 1) { - #Check if the input is integer ------------------------------------------- + # Check if the input is integer ------------------------------------------- integer_check = try(as.integer(baseline_category), silent = TRUE) - if(is.na(integer_check)) - stop(paste0("The baseline_category argument for the Blume-Capel model contains either \n", - "a missing value or a value that could not be forced into an integer value.")) + if(is.na(integer_check)) { + stop(paste0( + "The baseline_category argument for the Blume-Capel model contains either \n", + "a missing value or a value that could not be forced into an integer value." + )) + } integer_check = baseline_category - round(baseline_category) - if(integer_check > .Machine$double.eps) + if(integer_check > .Machine$double.eps) { stop("Reference category needs to an integer value or a vector of integers of length p.") + } baseline_category = rep.int(baseline_category, times = ncol(x)) } - #Check if the input is integer ------------------------------------------- + # Check if the input is integer ------------------------------------------- blume_capel_variables = which(!variable_bool) # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) integer_check = try(as.integer(baseline_category[blume_capel_variables]), - silent = TRUE) - if(anyNA(integer_check)) - stop(paste0("The baseline_category argument for the Blume-Capel model contains either \n", - "missing values or values that could not be forced into an integer value.")) + silent = TRUE + ) + if(anyNA(integer_check)) { + stop(paste0( + "The baseline_category argument for the Blume-Capel model contains either \n", + "missing values or values that could not be forced into an integer value." + )) + } integer_check = baseline_category[blume_capel_variables] - round(baseline_category[blume_capel_variables]) @@ -377,11 +484,15 @@ check_compare_model = function( if(any(integer_check > .Machine$double.eps)) { non_integers = blume_capel_variables[integer_check > .Machine$double.eps] if(length(non_integers) > 1) { - stop(paste0("The entries in baseline_category for variables ", - paste0(non_integers, collapse = ", "), " need to be integer.")) + stop(paste0( + "The entries in baseline_category for variables ", + paste0(non_integers, collapse = ", "), " need to be integer." + )) } else { - stop(paste0("The entry in baseline_category for variable ", - non_integers, " needs to be an integer.")) + stop(paste0( + "The entry in baseline_category for variable ", + non_integers, " needs to be an integer." + )) } } @@ -390,57 +501,69 @@ check_compare_model = function( if(any(baseline_category < variable_lower) | any(baseline_category > variable_upper)) { out_of_range = which(baseline_category < variable_lower | baseline_category > variable_upper) - stop(paste0("The Blume-Capel model assumes that the reference category is within the range \n", - "of the observed category scores. This was not the case for variable(s) \n", - paste0(out_of_range, collapse =", "), - ".")) + stop(paste0( + "The Blume-Capel model assumes that the reference category is within the range \n", + "of the observed category scores. This was not the case for variable(s) \n", + paste0(out_of_range, collapse = ", "), + "." + )) } } else { baseline_category = rep.int(0, times = ncol(x)) } - #Check prior set-up for the interaction parameters --------------------------- - if(pairwise_scale <= 0 || is.na(pairwise_scale) || is.infinite(pairwise_scale)) + # Check prior set-up for the interaction parameters --------------------------- + if(pairwise_scale <= 0 || is.na(pairwise_scale) || is.infinite(pairwise_scale)) { stop("The scale of the Cauchy prior for the interactions needs to be positive.") + } - #Check prior set-up for the interaction differences -------------------------- - if(difference_scale <= 0 || is.na(difference_scale) || is.infinite(difference_scale)) + # Check prior set-up for the interaction differences -------------------------- + if(difference_scale <= 0 || is.na(difference_scale) || is.infinite(difference_scale)) { stop("The scale of the Cauchy prior for the differences needs to be positive.") + } - #Check prior set-up for the threshold parameters ----------------------------- - if(main_alpha <= 0 | !is.finite(main_alpha)) + # Check prior set-up for the threshold parameters ----------------------------- + if(main_alpha <= 0 | !is.finite(main_alpha)) { stop("Parameter main_alpha needs to be positive.") - if(main_beta <= 0 | !is.finite(main_beta)) + } + if(main_beta <= 0 | !is.finite(main_beta)) { stop("Parameter main_beta needs to be positive.") + } - #Check set-up for the Bayesian difference selection model -------------------- + # Check set-up for the Bayesian difference selection model -------------------- difference_selection = as.logical(difference_selection) - if(is.na(difference_selection)) + if(is.na(difference_selection)) { stop("The parameter difference_selection needs to be TRUE or FALSE.") + } if(difference_selection == TRUE) { inclusion_probability_difference = matrix(0, - nrow = ncol(x), - ncol = ncol(x)) + nrow = ncol(x), + ncol = ncol(x) + ) difference_prior = match.arg(difference_prior) if(difference_prior == "Bernoulli") { if(length(difference_probability) == 1) { difference_inclusion_probability = difference_probability[1] - if(is.na(difference_inclusion_probability) || is.null(difference_inclusion_probability)) + if(is.na(difference_inclusion_probability) || is.null(difference_inclusion_probability)) { stop("There is no value specified for the inclusion probability for the differences.") - if(difference_inclusion_probability <= 0) + } + if(difference_inclusion_probability <= 0) { stop("The inclusion probability for differences needs to be positive.") - if(difference_inclusion_probability >= 1) + } + if(difference_inclusion_probability >= 1) { stop("The inclusion probability for differences cannot equal or exceed the value one.") + } - inclusion_probability_difference = matrix(difference_probability, - nrow = ncol(x), - ncol = ncol(x)) - + inclusion_probability_difference = matrix(difference_probability, + nrow = ncol(x), + ncol = ncol(x) + ) } else { if(!inherits(difference_probability, what = "matrix") && - !inherits(difference_probability, what = "data.frame")) + !inherits(difference_probability, what = "data.frame")) { stop("The input for the inclusion probability argument for differences needs to be a single number, matrix, or dataframe.") + } if(inherits(difference_probability, what = "data.frame")) { inclusion_probability_difference = data.matrix(difference_probability) @@ -448,31 +571,42 @@ check_compare_model = function( inclusion_probability_difference = difference_probability } - if(!isSymmetric(inclusion_probability_difference)) + if(!isSymmetric(inclusion_probability_difference)) { stop("The inclusion probability matrix needs to be symmetric.") - if(ncol(inclusion_probability_difference) != ncol(x)) - stop(paste0("The inclusion probability matrix needs to have as many rows (columns) as there\n", - " are variables in the data.")) + } + if(ncol(inclusion_probability_difference) != ncol(x)) { + stop(paste0( + "The inclusion probability matrix needs to have as many rows (columns) as there\n", + " are variables in the data." + )) + } if(anyNA(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)]) || - any(is.null(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)]))) + any(is.null(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)]))) { stop("One or more inclusion probabilities for differences are not specified.") - if(any(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)] <= 0)) + } + if(any(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)] <= 0)) { stop("One or more inclusion probabilities for differences are negative or zero.") - if(any(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)] >= 1)) + } + if(any(inclusion_probability_difference[lower.tri(inclusion_probability_difference, diag = TRUE)] >= 1)) { stop("One or more inclusion probabilities for differences are one or larger.") + } } } else { inclusion_probability_difference = matrix(0.5, - nrow = ncol(x), - ncol = ncol(x)) - if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0) + nrow = ncol(x), + ncol = ncol(x) + ) + if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0) { stop("The scale parameters of the beta distribution for the differences need to be positive.") - if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta)) + } + if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta)) { stop("The scale parameters of the beta distribution for the differences need to be finite.") + } if(is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) || - is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta)) + is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta)) { stop("The scale parameters of the beta distribution for the differences need to be specified.") + } } } else { difference_prior = "Not applicable" @@ -487,17 +621,18 @@ check_compare_model = function( baseline_category = baseline_category, difference_prior = difference_prior, inclusion_probability_difference = inclusion_probability_difference - ) + ) ) } progress_type_from_display_progress <- function(display_progress = c("per-chain", "total", "none")) { - if (is.logical(display_progress) && length(display_progress) == 1) { - if (is.na(display_progress)) + if(is.logical(display_progress) && length(display_progress) == 1) { + if(is.na(display_progress)) { stop("The display_progress argument must be a single logical value, but not NA.") - display_progress = if (display_progress) "per-chain" else "none" + } + display_progress = if(display_progress) "per-chain" else "none" } else { display_progress = match.arg(display_progress) } - return(if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L) + return(if(display_progress == "per-chain") 2L else if(display_progress == "total") 1L else 0L) } diff --git a/R/mcmc_summary.R b/R/mcmc_summary.R index 5ea5073..e240d88 100644 --- a/R/mcmc_summary.R +++ b/R/mcmc_summary.R @@ -7,7 +7,7 @@ combine_chains = function(fit, component) { niter = nrow(samples_list[[1]]) nparam = ncol(samples_list[[1]]) array3d = array(NA_real_, dim = c(niter, nchains, nparam)) - for (i in seq_len(nchains)) { + for(i in seq_len(nchains)) { array3d[, i, ] = samples_list[[i]] } array3d @@ -15,19 +15,22 @@ combine_chains = function(fit, component) { # Compute effective sample size and Rhat (Gelman-Rubin diagnostic) compute_rhat_ess = function(draws) { - tryCatch({ - if (is.matrix(draws) && ncol(draws) > 1) { - mcmc_list = coda::mcmc.list( - lapply(seq_len(ncol(draws)), function(i) coda::mcmc(draws[, i])) - ) - ess = coda::effectiveSize(mcmc_list) - rhat = coda::gelman.diag(mcmc_list, autoburnin = FALSE)$psrf[1] - } else { - ess = coda::effectiveSize(draws) - rhat = NA_real_ - } - list(ess = ess, rhat = rhat) - }, error = function(e) list(ess = NA_real_, rhat = NA_real_)) + tryCatch( + { + if(is.matrix(draws) && ncol(draws) > 1) { + mcmc_list = coda::mcmc.list( + lapply(seq_len(ncol(draws)), function(i) coda::mcmc(draws[, i])) + ) + ess = coda::effectiveSize(mcmc_list) + rhat = coda::gelman.diag(mcmc_list, autoburnin = FALSE)$psrf[1] + } else { + ess = coda::effectiveSize(draws) + rhat = NA_real_ + } + list(ess = ess, rhat = rhat) + }, + error = function(e) list(ess = NA_real_, rhat = NA_real_) + ) } # Basic summarizer for continuous parameters @@ -39,7 +42,7 @@ summarize_manual = function(fit, component = c("main_samples", "pairwise_samples result = matrix(NA, nparam, 5) colnames(result) = c("mean", "mcse", "sd", "n_eff", "Rhat") - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws = array3d[, , j] vec = as.vector(draws) result[j, "mean"] = mean(vec) @@ -55,7 +58,6 @@ summarize_manual = function(fit, component = c("main_samples", "pairwise_samples } else { data.frame(parameter = param_names, result, check.names = FALSE) } - } # Summarize binary indicator variables @@ -70,7 +72,7 @@ summarize_indicator = function(fit, component = c("indicator_samples"), param_na result = matrix(NA, nparam, 9) colnames(result) = c("mean", "sd", "mcse", "n0->0", "n0->1", "n1->0", "n1->1", "n_eff", "Rhat") - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws = array3d[, , j] vec = as.vector(draws) T = length(vec) @@ -84,7 +86,7 @@ summarize_indicator = function(fit, component = c("indicator_samples"), param_na n10 = sum(g_curr == 1 & g_next == 0) n11 = sum(g_curr == 1 & g_next == 1) - if (n01 + n10 == 0) { + if(n01 + n10 == 0) { n_eff = mcse = R = NA_real_ } else { a = n01 / (n00 + n01) @@ -102,8 +104,10 @@ summarize_indicator = function(fit, component = c("indicator_samples"), param_na if(is.null(param_names)) { data.frame(parameter = paste0("indicator [", seq_len(nparam), "]"), result, check.names = FALSE) } else { - data.frame(parameter = paste0(param_names, "- indicator"), - result, check.names = FALSE) + data.frame( + parameter = paste0(param_names, "- indicator"), + result, check.names = FALSE + ) } } @@ -115,17 +119,17 @@ summarize_slab = function(fit, component = c("pairwise_samples"), param_names = result = matrix(NA, nparam, 5) colnames(result) = c("mean", "sd", "mcse", "n_eff", "Rhat") - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws = array3d[, , j] vec = as.vector(draws) nonzero = vec != 0 vec = vec[nonzero] T = length(vec) - if (T > 10) { + if(T > 10) { eap = mean(vec) sdev = sd(vec) - est = compute_rhat_ess(vec) ##draws + est = compute_rhat_ess(vec) ## draws mcse = sdev / sqrt(est$ess) result[j, ] = c(eap, sdev, mcse, est$ess, est$rhat) } @@ -134,17 +138,18 @@ summarize_slab = function(fit, component = c("pairwise_samples"), param_names = if(is.null(param_names)) { data.frame(parameter = paste0("weight [", seq_len(nparam), "]"), result, check.names = FALSE) } else { - data.frame(parameter = paste0(param_names, "- weight"), - result, check.names = FALSE) + data.frame( + parameter = paste0(param_names, "- weight"), + result, check.names = FALSE + ) } } # Combined summary for pairwise parameters with selection summarize_pair = function(fit, - indicator_component = c("indicator_samples"), - slab_component = c("pairwise_samples"), - param_names = NULL -) { + indicator_component = c("indicator_samples"), + slab_component = c("pairwise_samples"), + param_names = NULL) { indicator_component = match.arg(indicator_component) # Add options later slab_component = match.arg(slab_component) # Add options later @@ -165,13 +170,13 @@ summarize_pair = function(fit, nchains = dim(array3d_pw)[2] T = prod(dim(array3d_pw)[1:2]) - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws_pw = array3d_pw[, , j] draws_id = array3d_id[, , j] - if (nchains > 1) { + if(nchains > 1) { chain_means = numeric(nchains) chain_vars = numeric(nchains) - for (chain in 1:nchains) { + for(chain in 1:nchains) { pi = mean(draws_id[, chain]) tmp = draws_pw[, chain] e = mean(tmp[tmp != 0]) @@ -186,18 +191,24 @@ summarize_pair = function(fit, } } - data.frame(parameter = paste0("V[", seq_len(nparam), "]"), - mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, - check.names = FALSE) + data.frame( + parameter = paste0("V[", seq_len(nparam), "]"), + mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, + check.names = FALSE + ) if(is.null(param_names)) { - data.frame(parameter = paste0("weight [", seq_len(nparam), "]"), - mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, - check.names = FALSE) + data.frame( + parameter = paste0("weight [", seq_len(nparam), "]"), + mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, + check.names = FALSE + ) } else { - data.frame(parameter = paste0(param_names, "- weight"), - mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, - check.names = FALSE) + data.frame( + parameter = paste0(param_names, "- weight"), + mean = eap, sd = sd, mcse = mcse, n_eff = n_eff, Rhat = rhat, + check.names = FALSE + ) } } @@ -205,7 +216,7 @@ summarize_pair = function(fit, summarize_fit = function(fit, edge_selection = FALSE) { main_summary = summarize_manual(fit, component = "main_samples") - if (!edge_selection) { + if(!edge_selection) { pair_summary = summarize_manual(fit, component = "pairwise_samples") } else { # Get indicators and slab draws @@ -219,12 +230,13 @@ summarize_fit = function(fit, edge_selection = FALSE) { # Use summarize_pair only where not always selected full_summary = summarize_pair(fit, - indicator_component = "indicator_samples", - slab_component = "pairwise_samples") + indicator_component = "indicator_samples", + slab_component = "pairwise_samples" + ) manual_summary = summarize_manual(fit, component = "pairwise_samples") # Replace rows in full_summary with manual results for fully selected entries - if (any(all_selected)) { + if(any(all_selected)) { full_summary[all_selected, ] = manual_summary[all_selected, ] } @@ -239,17 +251,17 @@ summarize_fit = function(fit, edge_selection = FALSE) { # Calculate convergence diagnostics on the pairwise cluster co-appearance values summarize_alloc_pairs = function(allocations, node_names = NULL) { - #stopifnot(is.list(allocations), length(allocations) >= 2) - n_ch = length(allocations) + # stopifnot(is.list(allocations), length(allocations) >= 2) + n_ch = length(allocations) n_iter = nrow(allocations[[1]]) - no_variables = ncol(allocations[[1]]) - for (c in seq_len(n_ch)) { + no_variables = ncol(allocations[[1]]) + for(c in seq_len(n_ch)) { stopifnot(nrow(allocations[[c]]) == n_iter, ncol(allocations[[c]]) == no_variables) } - if (!is.null(node_names)) stopifnot(length(node_names) == no_variables) + if(!is.null(node_names)) stopifnot(length(node_names) == no_variables) # all node pairs - Pairs = t(combn(seq_len(no_variables), 2)) + Pairs = t(combn(seq_len(no_variables), 2)) nparam = nrow(Pairs) result = matrix(NA, nparam, 9) @@ -258,15 +270,16 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) { # helper to construct a "time-series" get_draws_pair = function(i, j) { out = matrix(NA, n_iter, n_ch) - for (c in seq_len(n_ch)) { + for(c in seq_len(n_ch)) { Zc = allocations[[c]] out[, c] = as.integer(Zc[, i] == Zc[, j]) } out } - for (p in seq_len(nparam)) { - i = Pairs[p, 1]; j = Pairs[p, 2] + for(p in seq_len(nparam)) { + i = Pairs[p, 1] + j = Pairs[p, 2] draws = get_draws_pair(i, j) vec = as.vector(draws) @@ -281,7 +294,7 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) { n10 = sum(g_curr == 1 & g_next == 0) n11 = sum(g_curr == 1 & g_next == 1) - if (n01 + n10 == 0) { + if(n01 + n10 == 0) { n_eff = mcse = R = NA_real_ } else { a = n01 / (n00 + n01) @@ -295,11 +308,11 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) { result[p, ] = c(p_hat, sd, mcse, n00, n01, n10, n11, n_eff, R) } - if (is.null(node_names)) { - rn = paste0(Pairs[,1], "-", Pairs[,2]) + if(is.null(node_names)) { + rn = paste0(Pairs[, 1], "-", Pairs[, 2]) dimn = as.character(seq_len(no_variables)) } else { - rn = paste0(node_names[Pairs[,1]], "-", node_names[Pairs[,2]]) + rn = paste0(node_names[Pairs[, 1]], "-", node_names[Pairs[, 2]]) dimn = node_names } @@ -307,11 +320,14 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) { rownames(sbm_summary) = rn # construct the co-appearance matrix - co_occur_matrix = matrix(0, nrow = no_variables, ncol = no_variables, - dimnames = list(dimn, dimn)) + co_occur_matrix = matrix(0, + nrow = no_variables, ncol = no_variables, + dimnames = list(dimn, dimn) + ) diag(co_occur_matrix) = 1 - for (p in seq_len(nparam)) { - i = Pairs[p, 1]; j = Pairs[p, 2] + for(p in seq_len(nparam)) { + i = Pairs[p, 1] + j = Pairs[p, 2] m = sbm_summary[p, "mean"] co_occur_matrix[i, j] = m co_occur_matrix[j, i] = m @@ -330,7 +346,7 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) { find_representative_clustering = function(cluster_matrix) { stopifnot(is.matrix(cluster_matrix)) n_iter = nrow(cluster_matrix) - p = ncol(cluster_matrix) + p = ncol(cluster_matrix) # Build co-clustering (membership) matrices for all iterations @@ -350,7 +366,7 @@ find_representative_clustering = function(cluster_matrix) { # MODE representative hash_mat = function(M) paste(as.integer(t(M)), collapse = ",") keys = vapply(Ms, hash_mat, character(1)) - tab = table(keys) + tab = table(keys) key_mode = names(tab)[which.max(tab)] idx_mode = match(key_mode, keys) alloc_mode = cluster_matrix[idx_mode, , drop = TRUE] @@ -374,11 +390,12 @@ find_representative_clustering = function(cluster_matrix) { # DOI:10.1080/01621459.2016.1255636 #' @importFrom stats dpois compute_p_k_given_t = function( - t, - log_Vn, - dirichlet_alpha, - num_variables, - lambda) { + t, + log_Vn, + dirichlet_alpha, + num_variables, + lambda +) { # Define the K_values K_values = as.numeric(1:num_variables) @@ -393,9 +410,9 @@ compute_p_k_given_t = function( truncated_poisson_pmf = dpois(K_values, lambda) / norm_factor # Loop through each value of K - for (i in seq_along(K_values)) { + for(i in seq_along(K_values)) { K = K_values[i] - if (K >= t) { + if(K >= t) { # Falling factorial falling_factorial = prod(K:(K - t + 1)) # Rising factorial @@ -417,9 +434,9 @@ compute_p_k_given_t = function( # Wrapper function to compute the posterior summary for the Stochastic Block Model posterior_summary_SBM = function( - allocations, - arguments) { - + allocations, + arguments +) { # combine the allocations from the chains cluster_allocations = do.call(rbind, allocations) @@ -429,7 +446,8 @@ posterior_summary_SBM = function( # Pre-compute log_Vn for computing the cluster probabilities log_Vn = compute_Vn_mfm_sbm( - num_variables, dirichlet_alpha, num_variables + 10, lambda) + num_variables, dirichlet_alpha, num_variables + 10, lambda + ) # Compute the number of unique clusters (t) for each iteration, i.e., the # cardinality of the partition z @@ -439,16 +457,17 @@ posterior_summary_SBM = function( # row in clusters p_k_given_t = matrix(NA, nrow = length(clusters), ncol = num_variables) - for (i in 1:length(clusters)) { + for(i in 1:length(clusters)) { p_k_given_t[i, ] = compute_p_k_given_t( - clusters[i], log_Vn, dirichlet_alpha, num_variables, lambda) + clusters[i], log_Vn, dirichlet_alpha, num_variables, lambda + ) } # Average across all iterations p_k_given_t = colMeans(p_k_given_t) # Format the output - #num_blocks = 1:num_variables + # num_blocks = 1:num_variables blocks = cbind(p_k_given_t) colnames(blocks) = c("probability") @@ -458,9 +477,11 @@ posterior_summary_SBM = function( # Compute the mean and mode of the allocations allocations = find_representative_clustering(cluster_allocations) - return(list(blocks = blocks, - allocations_mean = allocations$mean, - allocations_mode = allocations$mode)) + return(list( + blocks = blocks, + allocations_mean = allocations$mean, + allocations_mode = allocations$mode + )) } @@ -471,7 +492,7 @@ combine_chains_compare = function(fit, component) { niter = nrow(samples_list[[1]]) nparam = ncol(samples_list[[1]]) array3d = array(NA_real_, dim = c(niter, nchains, nparam)) - for (i in seq_len(nchains)) { + for(i in seq_len(nchains)) { array3d[, i, ] = samples_list[[i]] } array3d @@ -484,7 +505,7 @@ summarize_manual_compare = function(fit_or_array, component = match.arg(component) # allow either a fit list or a pre-combined 3D array - if (is.array(fit_or_array)) { + if(is.array(fit_or_array)) { array3d = fit_or_array } else { array3d = combine_chains_compare(fit_or_array, component) @@ -494,18 +515,18 @@ summarize_manual_compare = function(fit_or_array, result = matrix(NA, nparam, 5) colnames(result) = c("mean", "mcse", "sd", "n_eff", "Rhat") - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws = array3d[, , j] - vec = as.vector(draws) + vec = as.vector(draws) result[j, "mean"] = mean(vec) - result[j, "sd"] = sd(vec) + result[j, "sd"] = sd(vec) est = compute_rhat_ess(draws) result[j, "mcse"] = sd(vec) / sqrt(est$ess) result[j, "n_eff"] = est$ess - result[j, "Rhat"] = est$rhat + result[j, "Rhat"] = est$rhat } - if (is.null(param_names)) { + if(is.null(param_names)) { data.frame(parameter = paste0("param [", seq_len(nparam)), result, check.names = FALSE) } else { data.frame(parameter = param_names, result, check.names = FALSE) @@ -513,8 +534,6 @@ summarize_manual_compare = function(fit_or_array, } - - summarize_indicator_compare = function(fit, component = "indicator_samples", param_names = NULL) { array3d = combine_chains_compare(fit, component) nparam = dim(array3d)[3] @@ -522,7 +541,7 @@ summarize_indicator_compare = function(fit, component = "indicator_samples", par result = matrix(NA, nparam, 9) colnames(result) = c("mean", "sd", "mcse", "n0->0", "n0->1", "n1->0", "n1->1", "n_eff", "Rhat") - for (j in seq_len(nparam)) { + for(j in seq_len(nparam)) { draws = array3d[, , j] vec = as.vector(draws) T = length(vec) @@ -536,7 +555,7 @@ summarize_indicator_compare = function(fit, component = "indicator_samples", par n10 = sum(g_curr == 1 & g_next == 0) n11 = sum(g_curr == 1 & g_next == 1) - if (n01 + n10 == 0) { + if(n01 + n10 == 0) { n_eff = mcse = R = NA_real_ } else { a = n01 / (n00 + n01) @@ -551,7 +570,7 @@ summarize_indicator_compare = function(fit, component = "indicator_samples", par result[j, ] = c(p_hat, sd, mcse, n00, n01, n10, n11, n_eff, R) } - if (is.null(param_names)) { + if(is.null(param_names)) { data.frame(parameter = paste0("indicator [", seq_len(nparam), "]"), result, check.names = FALSE) } else { data.frame(parameter = param_names, result, check.names = FALSE) @@ -562,7 +581,7 @@ summarize_indicator_compare = function(fit, component = "indicator_samples", par # Summarize one effect with spike-and-slab draws summarize_mixture_effect = function(draws_pw, draws_id, name) { nchains <- ncol(draws_pw) - niter <- nrow(draws_pw) + niter <- nrow(draws_pw) ## --- slab part --- vec <- as.vector(draws_pw) @@ -570,10 +589,10 @@ summarize_mixture_effect = function(draws_pw, draws_id, name) { vec <- vec[nonzero] T_slab <- length(vec) - if (T_slab > 10) { + if(T_slab > 10) { eap_slab <- mean(vec) var_slab <- var(vec) - est_slab <- compute_rhat_ess(vec) # treat as single chain + est_slab <- compute_rhat_ess(vec) # treat as single chain ess_slab <- est_slab$ess mcse_slab <- sqrt(var_slab) / sqrt(ess_slab) rhat_slab <- est_slab$rhat @@ -592,15 +611,15 @@ summarize_mixture_effect = function(draws_pw, draws_id, name) { g_curr <- vec_id[-T_id] p_hat <- mean(vec_id) - p_sd <- sqrt(p_hat * (1 - p_hat)) + p_sd <- sqrt(p_hat * (1 - p_hat)) - if (T_id > 1) { + if(T_id > 1) { n00 <- sum(g_curr == 0 & g_next == 0) n01 <- sum(g_curr == 0 & g_next == 1) n10 <- sum(g_curr == 1 & g_next == 0) n11 <- sum(g_curr == 1 & g_next == 1) - if (n01 + n10 == 0) { + if(n01 + n10 == 0) { p_mcse <- NA_real_ } else { a <- n01 / (n00 + n01) @@ -620,41 +639,41 @@ summarize_mixture_effect = function(draws_pw, draws_id, name) { mcse2 <- (eap_slab^2 * p_mcse^2) + (p_hat^2 * mcse_slab^2) - mcse <- if (is.finite(mcse2) && mcse2 > 0) sqrt(mcse2) else NA_real_ - n_eff <- if (!is.na(mcse) && mcse > 0) v / (mcse^2) else NA_real_ + mcse <- if(is.finite(mcse2) && mcse2 > 0) sqrt(mcse2) else NA_real_ + n_eff <- if(!is.na(mcse) && mcse > 0) v / (mcse^2) else NA_real_ ## --- Rhat (mixture, across chains) --- Rhat <- NA_real_ - if (nchains > 1) { + if(nchains > 1) { chain_means <- numeric(nchains) - chain_vars <- numeric(nchains) - for (ch in seq_len(nchains)) { + chain_vars <- numeric(nchains) + for(ch in seq_len(nchains)) { pi_ch <- mean(draws_id[, ch]) - tmp <- draws_pw[, ch] + tmp <- draws_pw[, ch] nz_ch <- tmp != 0 - if (isTRUE(any(nz_ch))) { + if(isTRUE(any(nz_ch))) { e_ch <- mean(tmp[nz_ch], na.rm = TRUE) - v_ch <- if (sum(nz_ch, na.rm = TRUE) > 1) var(tmp[nz_ch], na.rm = TRUE) else 0 + v_ch <- if(sum(nz_ch, na.rm = TRUE) > 1) var(tmp[nz_ch], na.rm = TRUE) else 0 } else { e_ch <- 0 v_ch <- 0 } chain_means[ch] <- pi_ch * e_ch - chain_vars[ch] <- pi_ch * (v_ch + (1 - pi_ch) * e_ch^2) + chain_vars[ch] <- pi_ch * (v_ch + (1 - pi_ch) * e_ch^2) } B <- niter * sum((chain_means - posterior_mean)^2) / (nchains - 1) W <- mean(chain_vars) V <- (niter - 1) * W / niter + B / niter - if (W > 0) Rhat <- sqrt(V / W) + if(W > 0) Rhat <- sqrt(V / W) } data.frame( parameter = name, - mean = posterior_mean, - sd = posterior_sd, - mcse = mcse, - n_eff = n_eff, - Rhat = Rhat, + mean = posterior_mean, + sd = posterior_sd, + mcse = mcse, + n_eff = n_eff, + Rhat = Rhat, check.names = FALSE ) } @@ -665,8 +684,8 @@ indicator_row_starts <- function(V) { # positions where each "row i" (i..V) starts in the flattened (i,j) list starts <- integer(V) starts[1L] <- 1L - if (V > 1L) { - for (i in 2L:V) { + if(V > 1L) { + for(i in 2L:V) { # previous row length = V - (i-1) + 1 starts[i] <- starts[i - 1L] + (V - (i - 1L) + 1L) } @@ -676,37 +695,37 @@ indicator_row_starts <- function(V) { summarize_main_diff_compare <- function( - fit, - main_effect_indices, - num_groups, - param_names = NULL + fit, + main_effect_indices, + num_groups, + param_names = NULL ) { main_effect_samples <- combine_chains_compare(fit, "main_samples") - indicator_samples <- combine_chains_compare(fit, "indicator_samples") + indicator_samples <- combine_chains_compare(fit, "indicator_samples") - V <- nrow(main_effect_indices) - num_main <- main_effect_indices[V, 2] + 1L # total rows in main-effects matrix - indicator_index_main <- function(i, V) indicator_row_starts(V)[i] + V <- nrow(main_effect_indices) + num_main <- main_effect_indices[V, 2] + 1L # total rows in main-effects matrix + indicator_index_main <- function(i, V) indicator_row_starts(V)[i] results <- list() counter <- 0L - for (v in seq_len(V)) { - id_idx <- indicator_index_main(v, V) # (v,v) position in flattened indicators + for(v in seq_len(V)) { + id_idx <- indicator_index_main(v, V) # (v,v) position in flattened indicators draws_id <- indicator_samples[, , id_idx] # rows in main-effects matrix belonging to variable v (1-based, inclusive) start <- main_effect_indices[v, 1] + 1L - stop <- main_effect_indices[v, 2] + 1L + stop <- main_effect_indices[v, 2] + 1L - for (row in start:stop) { + for(row in start:stop) { category <- row - start + 1L - for (h in 1L:(num_groups - 1L)) { - counter <- counter + 1L - col_index <- h * num_main + row # group-major blocks of length num_main - draws_pw <- main_effect_samples[, , col_index] + for(h in 1L:(num_groups - 1L)) { + counter <- counter + 1L + col_index <- h * num_main + row # group-major blocks of length num_main + draws_pw <- main_effect_samples[, , col_index] - pname <- if (!is.null(param_names)) { + pname <- if(!is.null(param_names)) { param_names[counter] } else { paste0("var", v, " (diff", h, "; ", category, ")") @@ -724,34 +743,34 @@ summarize_main_diff_compare <- function( summarize_pairwise_diff_compare <- function( - fit, - pairwise_effect_indices, - num_variables, - num_groups, - param_names = NULL + fit, + pairwise_effect_indices, + num_variables, + num_groups, + param_names = NULL ) { pairwise_effect_samples <- combine_chains_compare(fit, "pairwise_samples") - indicator_samples <- combine_chains_compare(fit, "indicator_samples") + indicator_samples <- combine_chains_compare(fit, "indicator_samples") - V <- num_variables - num_pair <- max(pairwise_effect_indices, na.rm = TRUE) + 1L # total rows in pairwise-effects matrix - indicator_index_pair <- function(i, j, V) indicator_row_starts(V)[i] + (j - i) # (i,j), i 0.001) { + if(divergence_rate > 0.001) { warning(sprintf( "About %.3f%% of transitions ended with a divergence (%d out of %d).\n", 100 * divergence_rate, total_divergences, nrow(divergent_mat) * ncol(divergent_mat) ), "Consider increasing the target acceptance rate.") - } else if (divergence_rate > 0) { - message(sprintf( - "Note: %.3f%% of transitions ended with a divergence (%d of %d).\n", - 100 * divergence_rate, - total_divergences, - nrow(divergent_mat) * ncol(divergent_mat) - ), - "Check R-hat and effective sample size (ESS) to ensure the chains are\n", - "mixing well.") + } else if(divergence_rate > 0) { + message( + sprintf( + "Note: %.3f%% of transitions ended with a divergence (%d of %d).\n", + 100 * divergence_rate, + total_divergences, + nrow(divergent_mat) * ncol(divergent_mat) + ), + "Check R-hat and effective sample size (ESS) to ensure the chains are\n", + "mixing well." + ) } depth_hit_rate <- max_tree_depth_hits / (nrow(treedepth_mat) * ncol(treedepth_mat)) - if (depth_hit_rate > 0.01) { + if(depth_hit_rate > 0.01) { warning(paste0( sprintf( "About %.2f%% of transitions hit the maximum tree depth (%d out of %d).\n", @@ -65,7 +67,7 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE) ), "Consider increasing max_depth." )) - } else if (depth_hit_rate > 0) { + } else if(depth_hit_rate > 0) { message(paste0( sprintf( "Note: %.2f%% of transitions hit the maximum tree depth (%d of %d).\n", @@ -82,16 +84,17 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE) low_ebfmi_chains <- which(ebfmi_per_chain < 0.3) min_ebfmi <- min(ebfmi_per_chain) - if (length(low_ebfmi_chains) > 0) { - warning(sprintf( - "E-BFMI below 0.3 detected in %d chain(s): %s.\n", - length(low_ebfmi_chains), - paste(low_ebfmi_chains, collapse = ", ") - ), - "This suggests inefficient momentum resampling in those chains.\n", - "Sampling efficiency may be reduced. Consider longer chains or checking convergence diagnostics.") + if(length(low_ebfmi_chains) > 0) { + warning( + sprintf( + "E-BFMI below 0.3 detected in %d chain(s): %s.\n", + length(low_ebfmi_chains), + paste(low_ebfmi_chains, collapse = ", ") + ), + "This suggests inefficient momentum resampling in those chains.\n", + "Sampling efficiency may be reduced. Consider longer chains or checking convergence diagnostics." + ) } - } # Return structured summary diff --git a/R/output_utils.R b/R/output_utils.R index 37f3ab4..264b7e2 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -1,11 +1,11 @@ prepare_output_bgm = function( - out, x, num_categories, iter, data_columnnames, is_ordinal_variable, - warmup, pairwise_scale, main_alpha, main_beta, - na_action, na_impute, edge_selection, edge_prior, inclusion_probability, - beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, - beta_bernoulli_beta_between,dirichlet_alpha, lambda, - variable_type, update_method, target_accept, hmc_num_leapfrogs, - nuts_max_depth, learn_mass_matrix, num_chains + out, x, num_categories, iter, data_columnnames, is_ordinal_variable, + warmup, pairwise_scale, main_alpha, main_beta, + na_action, na_impute, edge_selection, edge_prior, inclusion_probability, + beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, dirichlet_alpha, lambda, + variable_type, update_method, target_accept, hmc_num_leapfrogs, + nuts_max_depth, learn_mass_matrix, num_chains ) { arguments = list( prepared_data = x, @@ -37,7 +37,7 @@ prepare_output_bgm = function( num_chains = num_chains, num_categories = num_categories, data_columnnames = data_columnnames, - no_variables = ncol(x) #backwards compatibility easybgm + no_variables = ncol(x) # backwards compatibility easybgm ) num_variables = ncol(x) @@ -45,8 +45,8 @@ prepare_output_bgm = function( # ======= Parameter name generation ======= names_variable_categories = character() - for (v in seq_len(num_variables)) { - if (is_ordinal_variable[v]) { + for(v in seq_len(num_variables)) { + if(is_ordinal_variable[v]) { cats = seq_len(num_categories[v]) names_variable_categories = c( names_variable_categories, @@ -62,8 +62,8 @@ prepare_output_bgm = function( } edge_names = character() - for (i in 1:(num_variables - 1)) { - for (j in (i + 1):num_variables) { + for(i in 1:(num_variables - 1)) { + for(j in (i + 1):num_variables) { edge_names = c(edge_names, paste0(data_columnnames[i], "-", data_columnnames[j])) } } @@ -79,11 +79,11 @@ prepare_output_bgm = function( results$posterior_summary_main = main_summary results$posterior_summary_pairwise = pairwise_summary - if (edge_selection) { + if(edge_selection) { indicator_summary = summarize_indicator(out, param_names = edge_names)[, -1] rownames(indicator_summary) = edge_names results$posterior_summary_indicator = indicator_summary - if (identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) { + if(identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) { # convergence diagnostics of the co-apperance of the nodes sbm_convergence = summarize_alloc_pairs( allocations = lapply(out, `[[`, "allocations"), @@ -116,7 +116,7 @@ prepare_output_bgm = function( results$posterior_mean_main = pmm rownames(results$posterior_mean_main) = data_columnnames - colnames(results$posterior_mean_main) = paste0("cat (",1:ncol(pmm), ")") + colnames(results$posterior_mean_main) = paste0("cat (", 1:ncol(pmm), ")") results$posterior_mean_pairwise = matrix(0, nrow = num_variables, ncol = num_variables) results$posterior_mean_pairwise[lower.tri(results$posterior_mean_pairwise)] = pairwise_summary$mean @@ -124,7 +124,7 @@ prepare_output_bgm = function( rownames(results$posterior_mean_pairwise) = data_columnnames colnames(results$posterior_mean_pairwise) = data_columnnames - if (edge_selection) { + if(edge_selection) { indicator_means = indicator_summary$mean results$posterior_mean_indicator = matrix(0, nrow = num_variables, ncol = num_variables) results$posterior_mean_indicator[upper.tri(results$posterior_mean_indicator)] = indicator_means @@ -133,7 +133,7 @@ prepare_output_bgm = function( rownames(results$posterior_mean_indicator) = data_columnnames colnames(results$posterior_mean_indicator) = data_columnnames - if (identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) { + if(identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) { # convergence diagnostics of the co-apperance of the nodes sbm_convergence = summarize_alloc_pairs( allocations = lapply(out, `[[`, "allocations"), @@ -141,8 +141,10 @@ prepare_output_bgm = function( ) results$posterior_mean_coclustering_matrix = sbm_convergence$co_occur_matrix # calculate the estimated clustering and block probabilities - sbm_summary = posterior_summary_SBM(allocations = lapply(out, `[[`, "allocations"), - arguments = arguments) # check if only arguments would work + sbm_summary = posterior_summary_SBM( + allocations = lapply(out, `[[`, "allocations"), + arguments = arguments + ) # check if only arguments would work # extract the posterior mean and median results$posterior_mean_allocations = sbm_summary$allocations_mean results$posterior_mode_allocations = sbm_summary$allocations_mode @@ -157,17 +159,17 @@ prepare_output_bgm = function( class(results) = "bgms" results$raw_samples = list( - main = lapply(out, function(chain) chain$main_samples), - pairwise = lapply(out, function(chain) chain$pairwise_samples), - indicator = if (edge_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, - allocations = if (edge_selection && identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) lapply(out, `[[`, "allocations") else NULL, - nchains = length(out), - niter = nrow(out[[1]]$main_samples), + main = lapply(out, function(chain) chain$main_samples), + pairwise = lapply(out, function(chain) chain$pairwise_samples), + indicator = if(edge_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, + allocations = if(edge_selection && identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) lapply(out, `[[`, "allocations") else NULL, + nchains = length(out), + niter = nrow(out[[1]]$main_samples), parameter_names = list( - main = names_variable_categories, + main = names_variable_categories, pairwise = edge_names, - indicator = if (edge_selection) edge_names else NULL, - allocations = if (identical(edge_prior, "Stochastic-Block")) edge_names else NULL + indicator = if(edge_selection) edge_names else NULL, + allocations = if(identical(edge_prior, "Stochastic-Block")) edge_names else NULL ) ) @@ -175,19 +177,18 @@ prepare_output_bgm = function( } - # Generate names for bgmCompare parameters generate_param_names_bgmCompare = function( - data_columnnames, - num_categories, - is_ordinal_variable, - num_variables, - num_groups + data_columnnames, + num_categories, + is_ordinal_variable, + num_variables, + num_groups ) { # --- main baselines names_main_baseline = character() - for (v in seq_len(num_variables)) { - if (is_ordinal_variable[v]) { + for(v in seq_len(num_variables)) { + if(is_ordinal_variable[v]) { cats = seq_len(num_categories[v]) names_main_baseline = c( names_main_baseline, @@ -204,9 +205,9 @@ generate_param_names_bgmCompare = function( # --- main differences names_main_diff = character() - for (g in 2:num_groups) { - for (v in seq_len(num_variables)) { - if (is_ordinal_variable[v]) { + for(g in 2:num_groups) { + for(v in seq_len(num_variables)) { + if(is_ordinal_variable[v]) { cats = seq_len(num_categories[v]) names_main_diff = c( names_main_diff, @@ -224,8 +225,8 @@ generate_param_names_bgmCompare = function( # --- pairwise baselines names_pairwise_baseline = character() - for (i in 1:(num_variables - 1)) { - for (j in (i + 1):num_variables) { + for(i in 1:(num_variables - 1)) { + for(j in (i + 1):num_variables) { names_pairwise_baseline = c( names_pairwise_baseline, paste0(data_columnnames[i], "-", data_columnnames[j]) @@ -235,9 +236,9 @@ generate_param_names_bgmCompare = function( # --- pairwise differences names_pairwise_diff = character() - for (g in 2:num_groups) { - for (i in 1:(num_variables - 1)) { - for (j in (i + 1):num_variables) { + for(g in 2:num_groups) { + for(i in 1:(num_variables - 1)) { + for(j in (i + 1):num_variables) { names_pairwise_diff = c( names_pairwise_diff, paste0(data_columnnames[i], "-", data_columnnames[j], " (diff", g - 1, ")") @@ -248,14 +249,14 @@ generate_param_names_bgmCompare = function( # --- indicators generate_indicator_names <- function(data_columnnames) { - V <- length(data_columnnames) + V <- length(data_columnnames) out <- character() - for (i in seq_len(V)) { + for(i in seq_len(V)) { # main (diagonal) out <- c(out, paste0(data_columnnames[i], " (main)")) # then all pairs with i as the first index - if (i < V) { - for (j in seq.int(i + 1L, V)) { + if(i < V) { + for(j in seq.int(i + 1L, V)) { out <- c(out, paste0(data_columnnames[i], "-", data_columnnames[j], " (pairwise)")) } } @@ -277,15 +278,15 @@ generate_param_names_bgmCompare = function( prepare_output_bgmCompare = function( - out, observations, num_categories, is_ordinal_variable, - num_groups, group, iter, warmup, - main_effect_indices, pairwise_effect_indices, - data_columnnames, difference_selection, - difference_prior, difference_selection_alpha, difference_selection_beta, - inclusion_probability, - pairwise_scale, difference_scale, - update_method, target_accept, nuts_max_depth, hmc_num_leapfrogs, - learn_mass_matrix, num_chains, projection + out, observations, num_categories, is_ordinal_variable, + num_groups, group, iter, warmup, + main_effect_indices, pairwise_effect_indices, + data_columnnames, difference_selection, + difference_prior, difference_selection_alpha, difference_selection_beta, + inclusion_probability, + pairwise_scale, difference_scale, + update_method, target_accept, nuts_max_depth, hmc_num_leapfrogs, + learn_mass_matrix, num_chains, projection ) { num_variables = ncol(observations) @@ -345,7 +346,7 @@ prepare_output_bgmCompare = function( posterior_summary_pairwise_differences = summary_list$pairwise_differences ) - if (difference_selection) { + if(difference_selection) { results$posterior_summary_indicator = summary_list$indicators } @@ -372,7 +373,7 @@ prepare_output_bgmCompare = function( results$posterior_mean_main_baseline = pmm rownames(results$posterior_mean_main_baseline) = data_columnnames - colnames(results$posterior_mean_main_baseline) = paste0("cat (",1:ncol(pmm), ")") + colnames(results$posterior_mean_main_baseline) = paste0("cat (", 1:ncol(pmm), ")") results$posterior_mean_pairwise_baseline = matrix(0, num_variables, num_variables) results$posterior_mean_pairwise_baseline[lower.tri(results$posterior_mean_pairwise_baseline)] = @@ -384,11 +385,11 @@ prepare_output_bgmCompare = function( # --- raw samples (like in prepare_output_bgm) results$raw_samples = list( - main = lapply(out, function(chain) chain$main_samples), - pairwise = lapply(out, function(chain) chain$pairwise_samples), - indicator = if (difference_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, - nchains = length(out), - niter = nrow(out[[1]]$main_samples), + main = lapply(out, function(chain) chain$main_samples), + pairwise = lapply(out, function(chain) chain$pairwise_samples), + indicator = if(difference_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, + nchains = length(out), + niter = nrow(out[[1]]$main_samples), parameter_names = names_all ) @@ -396,4 +397,3 @@ prepare_output_bgmCompare = function( class(results) = c("bgmCompare") results } - diff --git a/R/sampleMRF.R b/R/sampleMRF.R index f8eed6f..487c627 100644 --- a/R/sampleMRF.R +++ b/R/sampleMRF.R @@ -76,11 +76,13 @@ #' Interactions = Interactions + t(Interactions) #' Thresholds = matrix(0, nrow = no_variables, ncol = max(no_categories)) #' -#' x = mrfSampler(no_states = 1e3, -#' no_variables = no_variables, -#' no_categories = no_categories, -#' interactions = Interactions, -#' thresholds = Thresholds) +#' x = mrfSampler( +#' no_states = 1e3, +#' no_variables = no_variables, +#' no_categories = no_categories, +#' interactions = Interactions, +#' thresholds = Thresholds +#' ) #' #' # Generate responses from a network of 2 ordinal and 3 Blume-Capel variables. #' no_variables = 5 @@ -97,13 +99,15 @@ #' Thresholds[3, ] = sort(-abs(rnorm(4)), decreasing = TRUE) #' Thresholds[5, ] = sort(-abs(rnorm(4)), decreasing = TRUE) #' -#' x = mrfSampler(no_states = 1e3, -#' no_variables = no_variables, -#' no_categories = no_categories, -#' interactions = Interactions, -#' thresholds = Thresholds, -#' variable_type = c("b","b","o","b","o"), -#' reference_category = 2) +#' x = mrfSampler( +#' no_states = 1e3, +#' no_variables = no_variables, +#' no_categories = no_categories, +#' interactions = Interactions, +#' thresholds = Thresholds, +#' variable_type = c("b", "b", "o", "b", "o"), +#' reference_category = 2 +#' ) #' #' @export mrfSampler = function(no_states, @@ -116,54 +120,69 @@ mrfSampler = function(no_states, iter = 1e3) { # Check no_states, no_variables, iter -------------------------------------------- if(no_states <= 0 || - abs(no_states - round(no_states)) > .Machine$double.eps) + abs(no_states - round(no_states)) > .Machine$double.eps) { stop("``no_states'' needs be a positive integer.") + } if(no_variables <= 0 || - abs(no_variables - round(no_variables)) > .Machine$double.eps) + abs(no_variables - round(no_variables)) > .Machine$double.eps) { stop("``no_variables'' needs be a positive integer.") + } if(iter <= 0 || - abs(iter - round(iter)) > .Machine$double.eps) + abs(iter - round(iter)) > .Machine$double.eps) { stop("``iter'' needs be a positive integer.") + } # Check no_categories -------------------------------------------------------- if(length(no_categories) == 1) { if(no_categories <= 0 || - abs(no_categories - round(no_categories)) > .Machine$double.eps) + abs(no_categories - round(no_categories)) > .Machine$double.eps) { stop("``no_categories'' needs be a (vector of) positive integer(s).") + } no_categories = rep(no_categories, no_variables) } else { for(variable in 1:no_variables) { if(no_categories[variable] <= 0 || - abs(no_categories[variable] - round(no_categories[variable])) > - .Machine$double.eps) + abs(no_categories[variable] - round(no_categories[variable])) > + .Machine$double.eps) { stop(paste("For variable", variable, "``no_categories'' was not a positive integer.")) + } } } # Check variable specification ----------------------------------------------- if(length(variable_type) == 1) { - variable_type = match.arg(arg = variable_type, - choices = c("ordinal", "blume-capel")) + variable_type = match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel") + ) if(variable_type == "blume-capel" && any(no_categories < 2)) { - stop(paste0("The Blume-Capel model only works for ordinal variables with more than two \n", - "response options. But variables ", which(no_categories < 2), " are binary variables.")) + stop(paste0( + "The Blume-Capel model only works for ordinal variables with more than two \n", + "response options. But variables ", which(no_categories < 2), " are binary variables." + )) } variable_type = rep(variable_type, no_variables) - } else { if(length(variable_type) != no_variables) { - stop(paste0("The argument ``variable_type'' should be either a single character string or a \n", - "vector of character strings of length ``no_variables''.")) + stop(paste0( + "The argument ``variable_type'' should be either a single character string or a \n", + "vector of character strings of length ``no_variables''." + )) } else { - for(variable in 1:no_variables) - variable_type[variable] = match.arg(arg = variable_type[variable], - choices = c("ordinal", "blume-capel")) + for(variable in 1:no_variables) { + variable_type[variable] = match.arg( + arg = variable_type[variable], + choices = c("ordinal", "blume-capel") + ) + } if(any(variable_type == "blume-capel" & no_categories < 2)) { - stop(paste0("The Blume-Capel model only works for ordinal variables with more than two \n", - "response options. But variables ", - which(variable_type == "blume-capel" & no_categories < 2), - " are binary variables.")) + stop(paste0( + "The Blume-Capel model only works for ordinal variables with more than two \n", + "response options. But variables ", + which(variable_type == "blume-capel" & no_categories < 2), + " are binary variables." + )) } } } @@ -174,24 +193,31 @@ mrfSampler = function(no_states, reference_category = rep(reference_category, no_variables) } if(any(reference_category < 0) || any(abs(reference_category - round(reference_category)) > .Machine$double.eps)) { - stop(paste0("For variables ", - which(reference_category < 0), - " ``reference_category'' was either negative or not integer.")) + stop(paste0( + "For variables ", + which(reference_category < 0), + " ``reference_category'' was either negative or not integer." + )) } if(any(reference_category - no_categories > 0)) { - stop(paste0("For variables ", - which(reference_category - no_categories > 0), - " the ``reference_category'' category was larger than the maximum category value.")) + stop(paste0( + "For variables ", + which(reference_category - no_categories > 0), + " the ``reference_category'' category was larger than the maximum category value." + )) } } # Check interactions --------------------------------------------------------- - if(!inherits(interactions, what = "matrix")) + if(!inherits(interactions, what = "matrix")) { interactions = as.matrix(interactions) - if(!isSymmetric(interactions)) + } + if(!isSymmetric(interactions)) { stop("The matrix ``interactions'' needs to be symmetric.") - if(nrow(interactions) != no_variables) + } + if(nrow(interactions) != no_variables) { stop("The matrix ``interactions'' needs to have ``no_variables'' rows and columns.") + } # Check the threshold values ------------------------------------------------- if(!inherits(thresholds, what = "matrix")) { @@ -199,19 +225,22 @@ mrfSampler = function(no_states, if(length(thresholds) == no_variables) { thresholds = matrix(thresholds, ncol = 1) } else { - stop(paste0("The matrix ``thresholds'' has ", - length(thresholds), - " elements, but requires", - no_variables, - ".")) + stop(paste0( + "The matrix ``thresholds'' has ", + length(thresholds), + " elements, but requires", + no_variables, + "." + )) } } else { stop("``Thresholds'' needs to be a matrix.") } } - if(nrow(thresholds) != no_variables) + if(nrow(thresholds) != no_variables) { stop("The matrix ``thresholds'' needs to be have ``no_variables'' rows.") + } for(variable in 1:no_variables) { if(variable_type[variable] != "blume-capel") { @@ -220,88 +249,107 @@ mrfSampler = function(no_states, string = paste(tmp, sep = ",") - stop(paste0("The matrix ``thresholds'' contains NA(s) for variable ", - variable, - " in category \n", - "(categories) ", - paste(which(is.na(thresholds[variable, 1:no_categories[variable]])), collapse = ", "), - ", where a numeric value is needed.")) + stop(paste0( + "The matrix ``thresholds'' contains NA(s) for variable ", + variable, + " in category \n", + "(categories) ", + paste(which(is.na(thresholds[variable, 1:no_categories[variable]])), collapse = ", "), + ", where a numeric value is needed." + )) } if(ncol(thresholds) > no_categories[variable]) { - if(!anyNA(thresholds[variable, (no_categories[variable]+1):ncol(thresholds)])) { - warning(paste0("The matrix ``thresholds'' contains numeric values for variable ", - variable, - " for category \n", - "(categories, i.e., columns) exceding the maximum of ", - no_categories[variable], - ". These values will \n", - "be ignored.")) + if(!anyNA(thresholds[variable, (no_categories[variable] + 1):ncol(thresholds)])) { + warning(paste0( + "The matrix ``thresholds'' contains numeric values for variable ", + variable, + " for category \n", + "(categories, i.e., columns) exceding the maximum of ", + no_categories[variable], + ". These values will \n", + "be ignored." + )) } } } else { if(anyNA(thresholds[variable, 1:2])) { - stop(paste0("The Blume-Capel model is chosen for the category thresholds of variable ", - variable, - ". \n", - "This model has two parameters that need to be placed in columns 1 and 2, row \n", - variable, - ", of the ``thresholds'' input matrix. Currently, there are NA(s) in these \n", - "entries, where a numeric value is needed.")) + stop(paste0( + "The Blume-Capel model is chosen for the category thresholds of variable ", + variable, + ". \n", + "This model has two parameters that need to be placed in columns 1 and 2, row \n", + variable, + ", of the ``thresholds'' input matrix. Currently, there are NA(s) in these \n", + "entries, where a numeric value is needed." + )) } if(ncol(thresholds) > 2) { if(!anyNA(thresholds[variable, 3:ncol(thresholds)])) { - warning(paste0("The Blume-Capel model is chosen for the category thresholds of variable ", - variable, - ". \n", - "This model has two parameters that need to be placed in columns 1 and 2, row \n", - variable, - ", of the ``thresholds'' input matrix. However, there are numeric values \n", - "in higher categories. These values will be ignored.")) + warning(paste0( + "The Blume-Capel model is chosen for the category thresholds of variable ", + variable, + ". \n", + "This model has two parameters that need to be placed in columns 1 and 2, row \n", + variable, + ", of the ``thresholds'' input matrix. However, there are numeric values \n", + "in higher categories. These values will be ignored." + )) } } } - } for(variable in 1:no_variables) { if(variable_type[variable] != "blume-capel") { for(category in 1:no_categories[variable]) { - if(!is.finite(thresholds[variable, category])) - stop(paste("The threshold parameter for variable", variable, "and category", - category, "is NA or not finite.")) + if(!is.finite(thresholds[variable, category])) { + stop(paste( + "The threshold parameter for variable", variable, "and category", + category, "is NA or not finite." + )) + } } } else { - if(!is.finite(thresholds[variable, 1])) + if(!is.finite(thresholds[variable, 1])) { stop(paste0( "The alpha parameter for the Blume-Capel model for variable ", variable, " is NA \n", - " or not finite.")) - if(!is.finite(thresholds[variable, 2])) - stop(paste0("The beta parameter for the Blume-Capel model for variable", - variable, - "is NA \n", - " or not finite.")) + " or not finite." + )) + } + if(!is.finite(thresholds[variable, 2])) { + stop(paste0( + "The beta parameter for the Blume-Capel model for variable", + variable, + "is NA \n", + " or not finite." + )) + } } } # The Gibbs sampler ---------------------------------------------------------- if(!any(variable_type == "blume-capel")) { - x <- sample_omrf_gibbs(no_states = no_states, - no_variables = no_variables, - no_categories = no_categories, - interactions = interactions, - thresholds = thresholds, - iter = iter) + x <- sample_omrf_gibbs( + no_states = no_states, + no_variables = no_variables, + no_categories = no_categories, + interactions = interactions, + thresholds = thresholds, + iter = iter + ) } else { - x <- sample_bcomrf_gibbs(no_states = no_states, - no_variables = no_variables, - no_categories = no_categories, - interactions = interactions, - thresholds = thresholds, - variable_type = variable_type, - reference_category = reference_category, - iter = iter) + x <- sample_bcomrf_gibbs( + no_states = no_states, + no_variables = no_variables, + no_categories = no_categories, + interactions = interactions, + thresholds = thresholds, + variable_type = variable_type, + reference_category = reference_category, + iter = iter + ) } return(x) diff --git a/Readme.Rmd b/Readme.Rmd index 2af31f3..dbdcefa 100644 --- a/Readme.Rmd +++ b/Readme.Rmd @@ -103,11 +103,11 @@ This makes it possible not only to detect structure and group differences, but a The current developmental version can be installed with ```{r gh-installation, eval = FALSE} -if (!requireNamespace("remotes")) { - install.packages("remotes") -} +if(!requireNamespace("remotes")) { + install.packages("remotes") +} remotes::install_github("Bayesian-Graphical-Modelling-Lab/bgms") ``` -## References \ No newline at end of file +## References diff --git a/tests/testthat/test-bgm.R b/tests/testthat/test-bgm.R index 9cd6fb2..03fc786 100644 --- a/tests/testthat/test-bgm.R +++ b/tests/testthat/test-bgm.R @@ -4,20 +4,20 @@ test_that("Posterior means correlate with sufficient precision statistics", { fit = bgm(Wenchuan, edge_selection = FALSE, iter = 100, warmup = 1000, seed = 1234, chains = 1) x = Wenchuan x = na.omit(x) - alt = -solve(t(x)%*%x) + alt = -solve(t(x) %*% x) alt = alt[lower.tri(alt)] posterior_means = colMeans(extract_pairwise_interactions(fit)) testthat::expect_gte(cor(posterior_means, alt, method = "spearman"), .98) }) -on_ci <- isTRUE(as.logical(Sys.getenv("CI", "false"))) -no_cores <- if (on_ci) 2L else min(4, parallel::detectCores()) +on_ci <- isTRUE(as.logical(Sys.getenv("CI", "false"))) +no_cores <- if(on_ci) 2L else min(4, parallel::detectCores()) test_that("bgm is reproducible", { testthat::skip_on_cran() data("Wenchuan", package = "bgms") - x <- Wenchuan[1:50, 1:5] + x <- Wenchuan[1:50, 1:5] fit1 <- bgm(x = x, iter = 100, warmup = 1000, cores = no_cores, seed = 1234) fit2 <- bgm(x = x, iter = 100, warmup = 1000, cores = no_cores, seed = 1234) @@ -35,11 +35,11 @@ test_that("bgmCompare is reproducible", { combine_chains <- function(fit) { pairs <- fit$raw_samples$pairwise - pair <- do.call(rbind, pairs) + pair <- do.call(rbind, pairs) mains <- fit$raw_samples$main - main <- do.call(rbind, mains) + main <- do.call(rbind, mains) inds <- fit$raw_samples$indicator - ind <- do.call(rbind, inds) + ind <- do.call(rbind, inds) return(cbind(main, pair, ind)) } diff --git a/vignettes/comparison.Rmd b/vignettes/comparison.Rmd index 3f35f18..c1aef9f 100644 --- a/vignettes/comparison.Rmd +++ b/vignettes/comparison.Rmd @@ -42,7 +42,7 @@ data_english = data_english[, 1:5] # Fitting a model ```{r, include=FALSE} - fit <- readRDS(system.file("extdata", "fit_boredom.rds", package = "bgms")) +fit <- readRDS(system.file("extdata", "fit_boredom.rds", package = "bgms")) ``` ```{r, eval = FALSE} fit = bgmCompare(x = data_french, y = data_english, seed = 1234) @@ -79,12 +79,13 @@ colnames(french_network) = colnames(data_french) rownames(french_network) = colnames(data_french) qgraph(french_network, - theme = "TeamFortress", - maximum = 1, - fade = FALSE, - color = c("#f0ae0e"), vsize = 10, repulsion = .9, - label.cex = 1, label.scale = "FALSE", - labels = colnames(data_french)) + theme = "TeamFortress", + maximum = 1, + fade = FALSE, + color = c("#f0ae0e"), vsize = 10, repulsion = .9, + label.cex = 1, label.scale = "FALSE", + labels = colnames(data_french) +) ``` # Next steps @@ -93,4 +94,4 @@ qgraph(french_network, - For diagnostics and convergence checks, see the *Diagnostics* vignette. - For additional analysis tools and more advanced plotting options, consider using the **easybgm** package, which integrates smoothly with - **bgms** objects. \ No newline at end of file + **bgms** objects. diff --git a/vignettes/diagnostics.Rmd b/vignettes/diagnostics.Rmd index eb68766..d9e2499 100644 --- a/vignettes/diagnostics.Rmd +++ b/vignettes/diagnostics.Rmd @@ -69,8 +69,10 @@ param_index = 1 chains = lapply(fit$raw_samples$pairwise, function(mat) mat[, param_index]) mcmc_obj = mcmc.list(lapply(chains, mcmc)) -traceplot(mcmc_obj, col = c("firebrick", "steelblue", "darkgreen", "goldenrod"), - main = "Traceplot of pairwise[1]") +traceplot(mcmc_obj, + col = c("firebrick", "steelblue", "darkgreen", "goldenrod"), + main = "Traceplot of pairwise[1]" +) ``` diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index fd7843a..887336b 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -86,16 +86,17 @@ library(qgraph) median_probability_network = coef(fit)$pairwise median_probability_network[coef(fit)$indicator < 0.5] = 0.0 -qgraph(median_probability_network, - theme = "TeamFortress", - maximum = 1, - fade = FALSE, - color = c("#f0ae0e"), vsize = 10, repulsion = .9, - label.cex = 1, label.scale = "FALSE", - labels = colnames(data)) +qgraph(median_probability_network, + theme = "TeamFortress", + maximum = 1, + fade = FALSE, + color = c("#f0ae0e"), vsize = 10, repulsion = .9, + label.cex = 1, label.scale = "FALSE", + labels = colnames(data) +) ``` # Next steps - For comparing groups, see `?bgmCompare` or the *Model Comparison* vignette. -- For diagnostics and convergence checks, see the *Diagnostics* vignette. \ No newline at end of file +- For diagnostics and convergence checks, see the *Diagnostics* vignette.