352352# '
353353# ' @export
354354bgm = function (
355- x ,
356- variable_type = " ordinal" ,
357- baseline_category ,
358- iter = 1e3 ,
359- warmup = 1e3 ,
360- pairwise_scale = 2.5 ,
361- main_alpha = 0.5 ,
362- main_beta = 0.5 ,
363- edge_selection = TRUE ,
364- edge_prior = c(" Bernoulli" , " Beta-Bernoulli" , " Stochastic-Block" ),
365- inclusion_probability = 0.5 ,
366- beta_bernoulli_alpha = 1 ,
367- beta_bernoulli_beta = 1 ,
368- beta_bernoulli_alpha_between = 1 ,
369- beta_bernoulli_beta_between = 1 ,
370- dirichlet_alpha = 1 ,
371- lambda = 1 ,
372- na_action = c(" listwise" , " impute" ),
373- update_method = c(" nuts" , " adaptive-metropolis" , " hamiltonian-mc" ),
374- target_accept ,
375- hmc_num_leapfrogs = 100 ,
376- nuts_max_depth = 10 ,
377- learn_mass_matrix = FALSE ,
378- chains = 4 ,
379- cores = parallel :: detectCores(),
380- display_progress = c(" per-chain" , " total" , " none" ),
381- seed = NULL ,
382- interaction_scale ,
383- burnin ,
384- save ,
385- threshold_alpha ,
386- threshold_beta
355+ x ,
356+ variable_type = " ordinal" ,
357+ baseline_category ,
358+ iter = 1e3 ,
359+ warmup = 1e3 ,
360+ pairwise_scale = 2.5 ,
361+ main_alpha = 0.5 ,
362+ main_beta = 0.5 ,
363+ edge_selection = TRUE ,
364+ edge_prior = c(" Bernoulli" , " Beta-Bernoulli" , " Stochastic-Block" ),
365+ inclusion_probability = 0.5 ,
366+ beta_bernoulli_alpha = 1 ,
367+ beta_bernoulli_beta = 1 ,
368+ beta_bernoulli_alpha_between = 1 ,
369+ beta_bernoulli_beta_between = 1 ,
370+ dirichlet_alpha = 1 ,
371+ lambda = 1 ,
372+ na_action = c(" listwise" , " impute" ),
373+ update_method = c(" nuts" , " adaptive-metropolis" , " hamiltonian-mc" ),
374+ target_accept ,
375+ hmc_num_leapfrogs = 100 ,
376+ nuts_max_depth = 10 ,
377+ learn_mass_matrix = FALSE ,
378+ chains = 4 ,
379+ cores = parallel :: detectCores(),
380+ display_progress = c(" per-chain" , " total" , " none" ),
381+ seed = NULL ,
382+ interaction_scale ,
383+ burnin ,
384+ save ,
385+ threshold_alpha ,
386+ threshold_beta
387387) {
388- if (hasArg(interaction_scale )) {
388+ if (hasArg(interaction_scale )) {
389389 lifecycle :: deprecate_warn(" 0.1.6.0" , " bgm(interaction_scale =)" , " bgm(pairwise_scale =)" )
390- if (! hasArg(pairwise_scale )) {
390+ if (! hasArg(pairwise_scale )) {
391391 pairwise_scale = interaction_scale
392392 }
393393 }
394394
395- if (hasArg(burnin )) {
395+ if (hasArg(burnin )) {
396396 lifecycle :: deprecate_warn(" 0.1.6.0" , " bgm(burnin =)" , " bgm(warmup =)" )
397- if (! hasArg(warmup )) {
397+ if (! hasArg(warmup )) {
398398 warmup = burnin
399399 }
400400 }
401401
402- if (hasArg(save )) {
402+ if (hasArg(save )) {
403403 lifecycle :: deprecate_warn(" 0.1.6.0" , " bgm(save =)" )
404404 }
405405
406- if (hasArg(threshold_alpha ) || hasArg(threshold_beta )) {
407- lifecycle :: deprecate_warn(" 0.1.6.0" ,
408- " bgm(threshold_alpha =, threshold_beta =)" ,
409- " bgm(main_alpha =, main_beta =)"
406+ if (hasArg(threshold_alpha ) || hasArg(threshold_beta )) {
407+ lifecycle :: deprecate_warn(
408+ " 0.1.6.0" ,
409+ " bgm(threshold_alpha =, threshold_beta =)" ,
410+ " bgm(main_alpha =, main_beta =)"
410411 )
411- if (! hasArg(main_alpha )) main_alpha = threshold_alpha
412- if (! hasArg(main_beta )) main_beta = threshold_beta
412+ if (! hasArg(main_alpha )) main_alpha = threshold_alpha
413+ if (! hasArg(main_beta )) main_beta = threshold_beta
413414 }
414415
415416 # Check update method
@@ -430,39 +431,45 @@ bgm = function(
430431 }
431432 }
432433
433- # Check data input ------------------------------------------------------------
434- if (! inherits(x , what = " matrix" ) && ! inherits(x , what = " data.frame" ))
434+ # Check data input ------------------------------------------------------------
435+ if (! inherits(x , what = " matrix" ) && ! inherits(x , what = " data.frame" )) {
435436 stop(" The input x needs to be a matrix or dataframe." )
436- if (inherits(x , what = " data.frame" ))
437+ }
438+ if (inherits(x , what = " data.frame" )) {
437439 x = data.matrix(x )
438- if (ncol(x ) < 2 )
440+ }
441+ if (ncol(x ) < 2 ) {
439442 stop(" The matrix x should have more than one variable (columns)." )
440- if (nrow(x ) < 2 )
443+ }
444+ if (nrow(x ) < 2 ) {
441445 stop(" The matrix x should have more than one observation (rows)." )
446+ }
442447
443- # Check model input -----------------------------------------------------------
444- model = check_model(x = x ,
445- variable_type = variable_type ,
446- baseline_category = baseline_category ,
447- pairwise_scale = pairwise_scale ,
448- main_alpha = main_alpha ,
449- main_beta = main_beta ,
450- edge_selection = edge_selection ,
451- edge_prior = edge_prior ,
452- inclusion_probability = inclusion_probability ,
453- beta_bernoulli_alpha = beta_bernoulli_alpha ,
454- beta_bernoulli_beta = beta_bernoulli_beta ,
455- beta_bernoulli_alpha_between = beta_bernoulli_alpha_between ,
456- beta_bernoulli_beta_between = beta_bernoulli_beta_between ,
457- dirichlet_alpha = dirichlet_alpha ,
458- lambda = lambda )
448+ # Check model input -----------------------------------------------------------
449+ model = check_model(
450+ x = x ,
451+ variable_type = variable_type ,
452+ baseline_category = baseline_category ,
453+ pairwise_scale = pairwise_scale ,
454+ main_alpha = main_alpha ,
455+ main_beta = main_beta ,
456+ edge_selection = edge_selection ,
457+ edge_prior = edge_prior ,
458+ inclusion_probability = inclusion_probability ,
459+ beta_bernoulli_alpha = beta_bernoulli_alpha ,
460+ beta_bernoulli_beta = beta_bernoulli_beta ,
461+ beta_bernoulli_alpha_between = beta_bernoulli_alpha_between ,
462+ beta_bernoulli_beta_between = beta_bernoulli_beta_between ,
463+ dirichlet_alpha = dirichlet_alpha ,
464+ lambda = lambda
465+ )
459466
460467 # check hyperparameters input
461468 # If user left them NULL, pass -1 to C++ (means: ignore between prior)
462- if (is.null(beta_bernoulli_alpha_between ) && is.null(beta_bernoulli_beta_between )) {
469+ if (is.null(beta_bernoulli_alpha_between ) && is.null(beta_bernoulli_beta_between )) {
463470 beta_bernoulli_alpha_between <- - 1.0
464- beta_bernoulli_beta_between <- - 1.0
465- } else if (is.null(beta_bernoulli_alpha_between ) || is.null(beta_bernoulli_beta_between )) {
471+ beta_bernoulli_beta_between <- - 1.0
472+ } else if (is.null(beta_bernoulli_alpha_between ) || is.null(beta_bernoulli_beta_between )) {
466473 stop(" If you wish to specify different between and within cluster probabilites,
467474 provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between,
468475 otherwise leave both NULL." )
@@ -479,11 +486,12 @@ bgm = function(
479486 edge_prior = model $ edge_prior
480487 inclusion_probability = model $ inclusion_probability
481488
482- # Check Gibbs input -----------------------------------------------------------
489+ # Check Gibbs input -----------------------------------------------------------
483490 check_positive_integer(iter , " iter" )
484491 check_non_negative_integer(warmup , " warmup" )
485- if (warmup < 1e3 )
492+ if (warmup < 1e3 ) {
486493 warning(" The warmup parameter is set to a low value. This may lead to unreliable results. Reset to a minimum of 1000 iterations." )
494+ }
487495 warmup = max(warmup , 1e3 ) # Set minimum warmup to 1000 iterations
488496
489497 check_positive_integer(hmc_num_leapfrogs , " hmc_num_leapfrogs" )
@@ -492,22 +500,27 @@ bgm = function(
492500 check_positive_integer(nuts_max_depth , " nuts_max_depth" )
493501 nuts_max_depth = max(nuts_max_depth , 1 ) # Set minimum nuts_max_depth to 1
494502
495- # Check na_action -------------------------------------------------------------
503+ # Check na_action -------------------------------------------------------------
496504 na_action_input = na_action
497505 na_action = try(match.arg(na_action ), silent = TRUE )
498- if (inherits(na_action , what = " try-error" ))
499- stop(paste0(" The na_action argument should equal listwise or impute, not " ,
500- na_action_input ,
501- " ." ))
506+ if (inherits(na_action , what = " try-error" )) {
507+ stop(paste0(
508+ " The na_action argument should equal listwise or impute, not " ,
509+ na_action_input ,
510+ " ."
511+ ))
512+ }
502513
503- # Check display_progress ------------------------------------------------------
514+ # Check display_progress ------------------------------------------------------
504515 progress_type = progress_type_from_display_progress(display_progress )
505516
506- # Format the data input -------------------------------------------------------
507- data = reformat_data(x = x ,
508- na_action = na_action ,
509- variable_bool = variable_bool ,
510- baseline_category = baseline_category )
517+ # Format the data input -------------------------------------------------------
518+ data = reformat_data(
519+ x = x ,
520+ na_action = na_action ,
521+ variable_bool = variable_bool ,
522+ baseline_category = baseline_category
523+ )
511524 x = data $ x
512525 num_categories = data $ num_categories
513526 missing_index = data $ missing_index
@@ -520,44 +533,47 @@ bgm = function(
520533
521534 # Starting value of model matrix ---------------------------------------------
522535 indicator = matrix (1 ,
523- nrow = num_variables ,
524- ncol = num_variables )
536+ nrow = num_variables ,
537+ ncol = num_variables
538+ )
525539
526540
527- # Starting values of interactions and thresholds (posterior mode) -------------
541+ # Starting values of interactions and thresholds (posterior mode) -------------
528542 interactions = matrix (0 , nrow = num_variables , ncol = num_variables )
529543 thresholds = matrix (0 , nrow = num_variables , ncol = max(num_categories ))
530544
531- # Precompute the number of observations per category for each variable --------
545+ # Precompute the number of observations per category for each variable --------
532546 counts_per_category = matrix (0 ,
533- nrow = max(num_categories ) + 1 ,
534- ncol = num_variables )
547+ nrow = max(num_categories ) + 1 ,
548+ ncol = num_variables
549+ )
535550 for (variable in 1 : num_variables ) {
536551 for (category in 0 : num_categories [variable ]) {
537552 counts_per_category [category + 1 , variable ] = sum(x [, variable ] == category )
538553 }
539554 }
540555
541- # Precompute the sufficient statistics for the two Blume-Capel parameters -----
556+ # Precompute the sufficient statistics for the two Blume-Capel parameters -----
542557 blume_capel_stats = matrix (0 , nrow = 2 , ncol = num_variables )
543558 if (any(! variable_bool )) {
544559 # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
545560 bc_vars = which(! variable_bool )
546561 for (i in bc_vars ) {
547562 blume_capel_stats [1 , i ] = sum(x [, i ])
548- blume_capel_stats [2 , i ] = sum((x [, i ] - baseline_category [i ]) ^ 2 )
563+ blume_capel_stats [2 , i ] = sum((x [, i ] - baseline_category [i ])^ 2 )
549564 }
550565 }
551566 pairwise_stats = t(x ) %*% x
552567
553568 # Index matrix used in the c++ functions ------------------------------------
554569 interaction_index_matrix = matrix (0 ,
555- nrow = num_variables * (num_variables - 1 ) / 2 ,
556- ncol = 3 )
570+ nrow = num_variables * (num_variables - 1 ) / 2 ,
571+ ncol = 3
572+ )
557573 cntr = 0
558574 for (variable1 in 1 : (num_variables - 1 )) {
559575 for (variable2 in (variable1 + 1 ): num_variables ) {
560- cntr = cntr + 1
576+ cntr = cntr + 1
561577 interaction_index_matrix [cntr , 1 ] = cntr - 1
562578 interaction_index_matrix [cntr , 2 ] = variable1 - 1
563579 interaction_index_matrix [cntr , 3 ] = variable2 - 1
@@ -566,21 +582,21 @@ bgm = function(
566582
567583 pairwise_effect_indices = matrix (NA , nrow = num_variables , ncol = num_variables )
568584 tel = 0
569- for (v1 in seq_len(num_variables - 1 )) {
570- for (v2 in seq((v1 + 1 ), num_variables )) {
585+ for (v1 in seq_len(num_variables - 1 )) {
586+ for (v2 in seq((v1 + 1 ), num_variables )) {
571587 pairwise_effect_indices [v1 , v2 ] = tel
572588 pairwise_effect_indices [v2 , v1 ] = tel
573- tel = tel + 1 # C++ starts at zero
589+ tel = tel + 1 # C++ starts at zero
574590 }
575591 }
576592
577- # Setting the seed
578- if (missing(seed ) || is.null(seed )) {
593+ # Setting the seed
594+ if (missing(seed ) || is.null(seed )) {
579595 # Draw a random seed if none provided
580596 seed = sample.int(.Machine $ integer.max , 1 )
581597 }
582598
583- if (! is.numeric(seed ) || length(seed ) != 1 || is.na(seed ) || seed < 0 ) {
599+ if (! is.numeric(seed ) || length(seed ) != 1 || is.na(seed ) || seed < 0 ) {
584600 stop(" Argument 'seed' must be a single non-negative integer." )
585601 }
586602
@@ -612,13 +628,13 @@ bgm = function(
612628
613629
614630 userInterrupt = any(vapply(out , FUN = `[[` , FUN.VALUE = logical (1L ), " userInterrupt" ))
615- if (userInterrupt ) {
631+ if (userInterrupt ) {
616632 warning(" Stopped sampling after user interrupt, results are likely uninterpretable." )
617633 # Try to prepare output, but catch any errors
618634 output <- tryCatch(
619635 prepare_output_bgm(
620636 out = out , x = x , num_categories = num_categories , iter = iter ,
621- data_columnnames = if (is.null(colnames(x ))) paste0(" Variable " , seq_len(ncol(x ))) else colnames(x ),
637+ data_columnnames = if (is.null(colnames(x ))) paste0(" Variable " , seq_len(ncol(x ))) else colnames(x ),
622638 is_ordinal_variable = variable_bool ,
623639 warmup = warmup , pairwise_scale = pairwise_scale ,
624640 main_alpha = main_alpha , main_beta = main_beta ,
@@ -647,9 +663,9 @@ bgm = function(
647663 }
648664
649665 # Main output handler in the wrapper function
650- output = prepare_output_bgm (
666+ output = prepare_output_bgm(
651667 out = out , x = x , num_categories = num_categories , iter = iter ,
652- data_columnnames = if (is.null(colnames(x ))) paste0(" Variable " , seq_len(ncol(x ))) else colnames(x ),
668+ data_columnnames = if (is.null(colnames(x ))) paste0(" Variable " , seq_len(ncol(x ))) else colnames(x ),
653669 is_ordinal_variable = variable_bool ,
654670 warmup = warmup , pairwise_scale = pairwise_scale ,
655671 main_alpha = main_alpha , main_beta = main_beta ,
@@ -669,7 +685,7 @@ bgm = function(
669685 num_chains = chains
670686 )
671687
672- if (update_method == " nuts" ) {
688+ if (update_method == " nuts" ) {
673689 nuts_diag = summarize_nuts_diagnostics(out , nuts_max_depth = nuts_max_depth )
674690 output $ nuts_diag = nuts_diag
675691 }
@@ -678,21 +694,23 @@ bgm = function(
678694 # TODO: REMOVE after easybgm >= 0.2.2 is on CRAN
679695 # Compatibility shim for easybgm <= 0.2.1
680696 # -------------------------------------------------------------------
681- if (" easybgm" %in% loadedNamespaces()) {
697+ if (" easybgm" %in% loadedNamespaces()) {
682698 ebgm_version <- utils :: packageVersion(" easybgm" )
683- if (ebgm_version < = " 0.2.1" ) {
684- warning(" bgms is running in compatibility mode for easybgm (<= 0.2.1). " ,
685- " This will be removed once easybgm >= 0.2.2 is on CRAN." )
699+ if (ebgm_version < = " 0.2.1" ) {
700+ warning(
701+ " bgms is running in compatibility mode for easybgm (<= 0.2.1). " ,
702+ " This will be removed once easybgm >= 0.2.2 is on CRAN."
703+ )
686704
687705 # Add legacy variables to output
688706 output $ arguments $ save <- TRUE
689- if (edge_selection ) {
707+ if (edge_selection ) {
690708 output $ indicator <- extract_indicators(output )
691709 }
692710 output $ interactions <- extract_pairwise_interactions(output )
693- output $ thresholds <- extract_category_thresholds(output )
711+ output $ thresholds <- extract_category_thresholds(output )
694712 }
695713 }
696714
697715 return (output )
698- }
716+ }
0 commit comments