Skip to content

Commit 33c11a3

Browse files
authored
style project (#69)
1 parent c925d92 commit 33c11a3

17 files changed

+1578
-1208
lines changed

R/bgm.R

Lines changed: 127 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -352,64 +352,65 @@
352352
#'
353353
#' @export
354354
bgm = function(
355-
x,
356-
variable_type = "ordinal",
357-
baseline_category,
358-
iter = 1e3,
359-
warmup = 1e3,
360-
pairwise_scale = 2.5,
361-
main_alpha = 0.5,
362-
main_beta = 0.5,
363-
edge_selection = TRUE,
364-
edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"),
365-
inclusion_probability = 0.5,
366-
beta_bernoulli_alpha = 1,
367-
beta_bernoulli_beta = 1,
368-
beta_bernoulli_alpha_between = 1,
369-
beta_bernoulli_beta_between = 1,
370-
dirichlet_alpha = 1,
371-
lambda = 1,
372-
na_action = c("listwise", "impute"),
373-
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
374-
target_accept,
375-
hmc_num_leapfrogs = 100,
376-
nuts_max_depth = 10,
377-
learn_mass_matrix = FALSE,
378-
chains = 4,
379-
cores = parallel::detectCores(),
380-
display_progress = c("per-chain", "total", "none"),
381-
seed = NULL,
382-
interaction_scale,
383-
burnin,
384-
save,
385-
threshold_alpha,
386-
threshold_beta
355+
x,
356+
variable_type = "ordinal",
357+
baseline_category,
358+
iter = 1e3,
359+
warmup = 1e3,
360+
pairwise_scale = 2.5,
361+
main_alpha = 0.5,
362+
main_beta = 0.5,
363+
edge_selection = TRUE,
364+
edge_prior = c("Bernoulli", "Beta-Bernoulli", "Stochastic-Block"),
365+
inclusion_probability = 0.5,
366+
beta_bernoulli_alpha = 1,
367+
beta_bernoulli_beta = 1,
368+
beta_bernoulli_alpha_between = 1,
369+
beta_bernoulli_beta_between = 1,
370+
dirichlet_alpha = 1,
371+
lambda = 1,
372+
na_action = c("listwise", "impute"),
373+
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
374+
target_accept,
375+
hmc_num_leapfrogs = 100,
376+
nuts_max_depth = 10,
377+
learn_mass_matrix = FALSE,
378+
chains = 4,
379+
cores = parallel::detectCores(),
380+
display_progress = c("per-chain", "total", "none"),
381+
seed = NULL,
382+
interaction_scale,
383+
burnin,
384+
save,
385+
threshold_alpha,
386+
threshold_beta
387387
) {
388-
if (hasArg(interaction_scale)) {
388+
if(hasArg(interaction_scale)) {
389389
lifecycle::deprecate_warn("0.1.6.0", "bgm(interaction_scale =)", "bgm(pairwise_scale =)")
390-
if (!hasArg(pairwise_scale)) {
390+
if(!hasArg(pairwise_scale)) {
391391
pairwise_scale = interaction_scale
392392
}
393393
}
394394

395-
if (hasArg(burnin)) {
395+
if(hasArg(burnin)) {
396396
lifecycle::deprecate_warn("0.1.6.0", "bgm(burnin =)", "bgm(warmup =)")
397-
if (!hasArg(warmup)) {
397+
if(!hasArg(warmup)) {
398398
warmup = burnin
399399
}
400400
}
401401

402-
if (hasArg(save)) {
402+
if(hasArg(save)) {
403403
lifecycle::deprecate_warn("0.1.6.0", "bgm(save =)")
404404
}
405405

406-
if (hasArg(threshold_alpha) || hasArg(threshold_beta)) {
407-
lifecycle::deprecate_warn("0.1.6.0",
408-
"bgm(threshold_alpha =, threshold_beta =)",
409-
"bgm(main_alpha =, main_beta =)"
406+
if(hasArg(threshold_alpha) || hasArg(threshold_beta)) {
407+
lifecycle::deprecate_warn(
408+
"0.1.6.0",
409+
"bgm(threshold_alpha =, threshold_beta =)",
410+
"bgm(main_alpha =, main_beta =)"
410411
)
411-
if (!hasArg(main_alpha)) main_alpha = threshold_alpha
412-
if (!hasArg(main_beta)) main_beta = threshold_beta
412+
if(!hasArg(main_alpha)) main_alpha = threshold_alpha
413+
if(!hasArg(main_beta)) main_beta = threshold_beta
413414
}
414415

415416
# Check update method
@@ -430,39 +431,45 @@ bgm = function(
430431
}
431432
}
432433

433-
#Check data input ------------------------------------------------------------
434-
if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame"))
434+
# Check data input ------------------------------------------------------------
435+
if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame")) {
435436
stop("The input x needs to be a matrix or dataframe.")
436-
if(inherits(x, what = "data.frame"))
437+
}
438+
if(inherits(x, what = "data.frame")) {
437439
x = data.matrix(x)
438-
if(ncol(x) < 2)
440+
}
441+
if(ncol(x) < 2) {
439442
stop("The matrix x should have more than one variable (columns).")
440-
if(nrow(x) < 2)
443+
}
444+
if(nrow(x) < 2) {
441445
stop("The matrix x should have more than one observation (rows).")
446+
}
442447

443-
#Check model input -----------------------------------------------------------
444-
model = check_model(x = x,
445-
variable_type = variable_type,
446-
baseline_category = baseline_category,
447-
pairwise_scale = pairwise_scale,
448-
main_alpha = main_alpha,
449-
main_beta = main_beta,
450-
edge_selection = edge_selection,
451-
edge_prior = edge_prior,
452-
inclusion_probability = inclusion_probability,
453-
beta_bernoulli_alpha = beta_bernoulli_alpha,
454-
beta_bernoulli_beta = beta_bernoulli_beta,
455-
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
456-
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
457-
dirichlet_alpha = dirichlet_alpha,
458-
lambda = lambda)
448+
# Check model input -----------------------------------------------------------
449+
model = check_model(
450+
x = x,
451+
variable_type = variable_type,
452+
baseline_category = baseline_category,
453+
pairwise_scale = pairwise_scale,
454+
main_alpha = main_alpha,
455+
main_beta = main_beta,
456+
edge_selection = edge_selection,
457+
edge_prior = edge_prior,
458+
inclusion_probability = inclusion_probability,
459+
beta_bernoulli_alpha = beta_bernoulli_alpha,
460+
beta_bernoulli_beta = beta_bernoulli_beta,
461+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
462+
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
463+
dirichlet_alpha = dirichlet_alpha,
464+
lambda = lambda
465+
)
459466

460467
# check hyperparameters input
461468
# If user left them NULL, pass -1 to C++ (means: ignore between prior)
462-
if (is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) {
469+
if(is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) {
463470
beta_bernoulli_alpha_between <- -1.0
464-
beta_bernoulli_beta_between <- -1.0
465-
} else if (is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
471+
beta_bernoulli_beta_between <- -1.0
472+
} else if(is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
466473
stop("If you wish to specify different between and within cluster probabilites,
467474
provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between,
468475
otherwise leave both NULL.")
@@ -479,11 +486,12 @@ bgm = function(
479486
edge_prior = model$edge_prior
480487
inclusion_probability = model$inclusion_probability
481488

482-
#Check Gibbs input -----------------------------------------------------------
489+
# Check Gibbs input -----------------------------------------------------------
483490
check_positive_integer(iter, "iter")
484491
check_non_negative_integer(warmup, "warmup")
485-
if(warmup < 1e3)
492+
if(warmup < 1e3) {
486493
warning("The warmup parameter is set to a low value. This may lead to unreliable results. Reset to a minimum of 1000 iterations.")
494+
}
487495
warmup = max(warmup, 1e3) # Set minimum warmup to 1000 iterations
488496

489497
check_positive_integer(hmc_num_leapfrogs, "hmc_num_leapfrogs")
@@ -492,22 +500,27 @@ bgm = function(
492500
check_positive_integer(nuts_max_depth, "nuts_max_depth")
493501
nuts_max_depth = max(nuts_max_depth, 1) # Set minimum nuts_max_depth to 1
494502

495-
#Check na_action -------------------------------------------------------------
503+
# Check na_action -------------------------------------------------------------
496504
na_action_input = na_action
497505
na_action = try(match.arg(na_action), silent = TRUE)
498-
if(inherits(na_action, what = "try-error"))
499-
stop(paste0("The na_action argument should equal listwise or impute, not ",
500-
na_action_input,
501-
"."))
506+
if(inherits(na_action, what = "try-error")) {
507+
stop(paste0(
508+
"The na_action argument should equal listwise or impute, not ",
509+
na_action_input,
510+
"."
511+
))
512+
}
502513

503-
#Check display_progress ------------------------------------------------------
514+
# Check display_progress ------------------------------------------------------
504515
progress_type = progress_type_from_display_progress(display_progress)
505516

506-
#Format the data input -------------------------------------------------------
507-
data = reformat_data(x = x,
508-
na_action = na_action,
509-
variable_bool = variable_bool,
510-
baseline_category = baseline_category)
517+
# Format the data input -------------------------------------------------------
518+
data = reformat_data(
519+
x = x,
520+
na_action = na_action,
521+
variable_bool = variable_bool,
522+
baseline_category = baseline_category
523+
)
511524
x = data$x
512525
num_categories = data$num_categories
513526
missing_index = data$missing_index
@@ -520,44 +533,47 @@ bgm = function(
520533

521534
# Starting value of model matrix ---------------------------------------------
522535
indicator = matrix(1,
523-
nrow = num_variables,
524-
ncol = num_variables)
536+
nrow = num_variables,
537+
ncol = num_variables
538+
)
525539

526540

527-
#Starting values of interactions and thresholds (posterior mode) -------------
541+
# Starting values of interactions and thresholds (posterior mode) -------------
528542
interactions = matrix(0, nrow = num_variables, ncol = num_variables)
529543
thresholds = matrix(0, nrow = num_variables, ncol = max(num_categories))
530544

531-
#Precompute the number of observations per category for each variable --------
545+
# Precompute the number of observations per category for each variable --------
532546
counts_per_category = matrix(0,
533-
nrow = max(num_categories) + 1,
534-
ncol = num_variables)
547+
nrow = max(num_categories) + 1,
548+
ncol = num_variables
549+
)
535550
for(variable in 1:num_variables) {
536551
for(category in 0:num_categories[variable]) {
537552
counts_per_category[category + 1, variable] = sum(x[, variable] == category)
538553
}
539554
}
540555

541-
#Precompute the sufficient statistics for the two Blume-Capel parameters -----
556+
# Precompute the sufficient statistics for the two Blume-Capel parameters -----
542557
blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables)
543558
if(any(!variable_bool)) {
544559
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
545560
bc_vars = which(!variable_bool)
546561
for(i in bc_vars) {
547562
blume_capel_stats[1, i] = sum(x[, i])
548-
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
563+
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
549564
}
550565
}
551566
pairwise_stats = t(x) %*% x
552567

553568
# Index matrix used in the c++ functions ------------------------------------
554569
interaction_index_matrix = matrix(0,
555-
nrow = num_variables * (num_variables - 1) / 2,
556-
ncol = 3)
570+
nrow = num_variables * (num_variables - 1) / 2,
571+
ncol = 3
572+
)
557573
cntr = 0
558574
for(variable1 in 1:(num_variables - 1)) {
559575
for(variable2 in (variable1 + 1):num_variables) {
560-
cntr = cntr + 1
576+
cntr = cntr + 1
561577
interaction_index_matrix[cntr, 1] = cntr - 1
562578
interaction_index_matrix[cntr, 2] = variable1 - 1
563579
interaction_index_matrix[cntr, 3] = variable2 - 1
@@ -566,21 +582,21 @@ bgm = function(
566582

567583
pairwise_effect_indices = matrix(NA, nrow = num_variables, ncol = num_variables)
568584
tel = 0
569-
for (v1 in seq_len(num_variables - 1)) {
570-
for (v2 in seq((v1 + 1), num_variables)) {
585+
for(v1 in seq_len(num_variables - 1)) {
586+
for(v2 in seq((v1 + 1), num_variables)) {
571587
pairwise_effect_indices[v1, v2] = tel
572588
pairwise_effect_indices[v2, v1] = tel
573-
tel = tel + 1 # C++ starts at zero
589+
tel = tel + 1 # C++ starts at zero
574590
}
575591
}
576592

577-
#Setting the seed
578-
if (missing(seed) || is.null(seed)) {
593+
# Setting the seed
594+
if(missing(seed) || is.null(seed)) {
579595
# Draw a random seed if none provided
580596
seed = sample.int(.Machine$integer.max, 1)
581597
}
582598

583-
if (!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
599+
if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
584600
stop("Argument 'seed' must be a single non-negative integer.")
585601
}
586602

@@ -612,13 +628,13 @@ bgm = function(
612628

613629

614630
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
615-
if (userInterrupt) {
631+
if(userInterrupt) {
616632
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
617633
# Try to prepare output, but catch any errors
618634
output <- tryCatch(
619635
prepare_output_bgm(
620636
out = out, x = x, num_categories = num_categories, iter = iter,
621-
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
637+
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
622638
is_ordinal_variable = variable_bool,
623639
warmup = warmup, pairwise_scale = pairwise_scale,
624640
main_alpha = main_alpha, main_beta = main_beta,
@@ -647,9 +663,9 @@ bgm = function(
647663
}
648664

649665
# Main output handler in the wrapper function
650-
output = prepare_output_bgm (
666+
output = prepare_output_bgm(
651667
out = out, x = x, num_categories = num_categories, iter = iter,
652-
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
668+
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
653669
is_ordinal_variable = variable_bool,
654670
warmup = warmup, pairwise_scale = pairwise_scale,
655671
main_alpha = main_alpha, main_beta = main_beta,
@@ -669,7 +685,7 @@ bgm = function(
669685
num_chains = chains
670686
)
671687

672-
if (update_method == "nuts") {
688+
if(update_method == "nuts") {
673689
nuts_diag = summarize_nuts_diagnostics(out, nuts_max_depth = nuts_max_depth)
674690
output$nuts_diag = nuts_diag
675691
}
@@ -678,21 +694,23 @@ bgm = function(
678694
# TODO: REMOVE after easybgm >= 0.2.2 is on CRAN
679695
# Compatibility shim for easybgm <= 0.2.1
680696
# -------------------------------------------------------------------
681-
if ("easybgm" %in% loadedNamespaces()) {
697+
if("easybgm" %in% loadedNamespaces()) {
682698
ebgm_version <- utils::packageVersion("easybgm")
683-
if (ebgm_version <= "0.2.1") {
684-
warning("bgms is running in compatibility mode for easybgm (<= 0.2.1). ",
685-
"This will be removed once easybgm >= 0.2.2 is on CRAN.")
699+
if(ebgm_version <= "0.2.1") {
700+
warning(
701+
"bgms is running in compatibility mode for easybgm (<= 0.2.1). ",
702+
"This will be removed once easybgm >= 0.2.2 is on CRAN."
703+
)
686704

687705
# Add legacy variables to output
688706
output$arguments$save <- TRUE
689-
if (edge_selection) {
707+
if(edge_selection) {
690708
output$indicator <- extract_indicators(output)
691709
}
692710
output$interactions <- extract_pairwise_interactions(output)
693-
output$thresholds <- extract_category_thresholds(output)
711+
output$thresholds <- extract_category_thresholds(output)
694712
}
695713
}
696714

697715
return(output)
698-
}
716+
}

0 commit comments

Comments
 (0)