@@ -1015,7 +1015,7 @@ void gibbs_update_step_bgm (
10151015 const arma::uvec& is_ordinal_variable,
10161016 const arma::ivec& baseline_category,
10171017 const int iteration,
1018- const std::string& update_method,
1018+ const UpdateMethod update_method,
10191019 const arma::imat& pairwise_effect_indices,
10201020 arma::imat& pairwise_stats,
10211021 const int hmc_num_leapfrogs,
@@ -1051,7 +1051,7 @@ void gibbs_update_step_bgm (
10511051 }
10521052
10531053 // Step 2a: Update interaction weights for active edges
1054- if (update_method == " adaptive-metropolis " ) {
1054+ if (update_method == adaptive_metropolis ) {
10551055 update_pairwise_effects_metropolis_bgm (
10561056 pairwise_effects, main_effects, inclusion_indicator, observations,
10571057 num_categories, proposal_sd_pairwise, adapt_pairwise, pairwise_scale,
@@ -1061,7 +1061,7 @@ void gibbs_update_step_bgm (
10611061 }
10621062
10631063 // Step 2b: Update main effect (main_effect) parameters
1064- if (update_method == " adaptive-metropolis " ) {
1064+ if (update_method == adaptive_metropolis ) {
10651065 update_main_effects_metropolis_bgm (
10661066 main_effects, observations, num_categories, counts_per_category,
10671067 blume_capel_stats, baseline_category, is_ordinal_variable,
@@ -1072,7 +1072,7 @@ void gibbs_update_step_bgm (
10721072 }
10731073
10741074 // Step 2: Update joint parameters if applicable
1075- if (update_method == " hamiltonian-mc " ) {
1075+ if (update_method == hamiltonian_mc ) {
10761076 update_hmc_bgm (
10771077 main_effects, pairwise_effects, inclusion_indicator, observations,
10781078 num_categories, counts_per_category, blume_capel_stats,
@@ -1081,7 +1081,7 @@ void gibbs_update_step_bgm (
10811081 iteration, adapt, learn_mass_matrix, schedule.selection_enabled (iteration),
10821082 rng
10831083 );
1084- } else if (update_method == " nuts" ) {
1084+ } else if (update_method == nuts) {
10851085 SamplerResult result = update_nuts_bgm (
10861086 main_effects, pairwise_effects, inclusion_indicator,
10871087 observations, num_categories, counts_per_category, blume_capel_stats,
@@ -1171,7 +1171,7 @@ Rcpp::List run_gibbs_sampler_bgm(
11711171 arma::imat observations,
11721172 const arma::ivec& num_categories,
11731173 const double pairwise_scale,
1174- const std::string& edge_prior,
1174+ const EdgePrior edge_prior,
11751175 arma::mat inclusion_probability,
11761176 const double beta_bernoulli_alpha,
11771177 const double beta_bernoulli_beta,
@@ -1189,7 +1189,7 @@ Rcpp::List run_gibbs_sampler_bgm(
11891189 const arma::uvec& is_ordinal_variable,
11901190 const arma::ivec& baseline_category,
11911191 bool edge_selection,
1192- const std::string& update_method,
1192+ const UpdateMethod update_method,
11931193 const arma::imat pairwise_effect_indices,
11941194 const double target_accept,
11951195 arma::imat pairwise_stats,
@@ -1223,7 +1223,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12231223 if (edge_selection) {
12241224 indicator_samples.set_size (iter, num_pairwise);
12251225 }
1226- if (edge_selection && edge_prior == " Stochastic-Block " ) {
1226+ if (edge_selection && edge_prior == Stochastic_Block ) {
12271227 allocation_samples.set_size (iter, num_variables);
12281228 }
12291229
@@ -1245,7 +1245,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12451245 arma::vec log_Vn (1 );
12461246
12471247 // --- Initialize SBM prior if applicable
1248- if (edge_prior == " Stochastic-Block " ) {
1248+ if (edge_prior == Stochastic_Block ) {
12491249 cluster_allocations[0 ] = 0 ;
12501250 cluster_allocations[1 ] = 1 ;
12511251 for (int i = 2 ; i < num_variables; i++) {
@@ -1273,7 +1273,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12731273
12741274 // --- Optional HMC/NUTS warmup stage
12751275 double initial_step_size_joint = 1.0 ;
1276- if (update_method == " hamiltonian-mc " || update_method == " nuts" ) {
1276+ if (update_method == hamiltonian_mc || update_method == nuts) {
12771277 initial_step_size_joint = find_initial_stepsize_bgm (
12781278 main_effects, pairwise_effects, inclusion_indicator, observations,
12791279 num_categories, counts_per_category, blume_capel_stats,
@@ -1283,7 +1283,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12831283 }
12841284
12851285 // --- Warmup scheduling + adaptation controller
1286- WarmupSchedule warmup_schedule (warmup, edge_selection, (update_method != " adaptive-metropolis " ));
1286+ WarmupSchedule warmup_schedule (warmup, edge_selection, (update_method != adaptive_metropolis ));
12871287 HMCAdaptationController adapt_joint (
12881288 num_main + num_pairwise, initial_step_size_joint, target_accept,
12891289 warmup_schedule, learn_mass_matrix
@@ -1339,7 +1339,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13391339
13401340 // --- Update edge probabilities under the prior (if edge selection is active)
13411341 if (warmup_schedule.selection_enabled (iteration)) {
1342- if (edge_prior == " Beta-Bernoulli " ) {
1342+ if (edge_prior == Beta_Bernoulli ) {
13431343 int num_edges_included = 0 ;
13441344 for (int i = 0 ; i < num_variables - 1 ; i++)
13451345 for (int j = i + 1 ; j < num_variables; j++)
@@ -1354,7 +1354,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13541354 for (int j = i + 1 ; j < num_variables; j++)
13551355 inclusion_probability (i, j) = inclusion_probability (j, i) = prob;
13561356
1357- } else if (edge_prior == " Stochastic-Block " ) {
1357+ } else if (edge_prior == Stochastic_Block ) {
13581358 cluster_allocations = block_allocations_mfm_sbm (
13591359 cluster_allocations, num_variables, log_Vn, cluster_prob,
13601360 arma::conv_to<arma::umat>::from (inclusion_indicator), dirichlet_alpha,
@@ -1396,7 +1396,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13961396 }
13971397 }
13981398
1399- if (edge_selection && edge_prior == " Stochastic-Block " ) {
1399+ if (edge_selection && edge_prior == Stochastic_Block ) {
14001400 for (int v = 0 ; v < num_variables; v++) {
14011401 allocation_samples (sample_index, v) = cluster_allocations[v] + 1 ;
14021402 }
@@ -1408,7 +1408,7 @@ Rcpp::List run_gibbs_sampler_bgm(
14081408 out[" main_samples" ] = main_effect_samples;
14091409 out[" pairwise_samples" ] = pairwise_effect_samples;
14101410
1411- if (update_method == " nuts" ) {
1411+ if (update_method == nuts) {
14121412 out[" treedepth__" ] = treedepth_samples;
14131413 out[" divergent__" ] = divergent_samples;
14141414 out[" energy__" ] = energy_samples;
@@ -1418,7 +1418,7 @@ Rcpp::List run_gibbs_sampler_bgm(
14181418 out[" indicator_samples" ] = indicator_samples;
14191419 }
14201420
1421- if (edge_selection && edge_prior == " Stochastic-Block " ) {
1421+ if (edge_selection && edge_prior == Stochastic_Block ) {
14221422 out[" allocations" ] = allocation_samples;
14231423 }
14241424
0 commit comments