Skip to content

Commit a79f2be

Browse files
committed
use enums instead of strings
1 parent a9d2dd4 commit a79f2be

File tree

7 files changed

+71
-36
lines changed

7 files changed

+71
-36
lines changed

src/bgmCompare_parallel.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "progress_manager.h"
1010
#include "sampler_output.h"
1111
#include "mcmc_adaptation.h"
12+
#include "common_helpers.h"
1213

1314
using namespace Rcpp;
1415
using namespace RcppParallel;
@@ -131,7 +132,7 @@ struct GibbsCompareChainRunner : public Worker {
131132
const arma::mat& inclusion_probability_master;
132133
// RNG seeds
133134
const std::vector<SafeRNG>& chain_rngs;
134-
const std::string& update_method;
135+
const UpdateMethod update_method;
135136
const int hmc_num_leapfrogs;
136137
ProgressManager& pm;
137138
// output
@@ -169,7 +170,7 @@ struct GibbsCompareChainRunner : public Worker {
169170
const arma::imat& interaction_index_matrix,
170171
const arma::mat& inclusion_probability_master,
171172
const std::vector<SafeRNG>& chain_rngs,
172-
const std::string& update_method,
173+
const UpdateMethod update_method,
173174
const int hmc_num_leapfrogs,
174175
ProgressManager& pm,
175176
std::vector<ChainResultCompare>& results
@@ -394,8 +395,10 @@ Rcpp::List run_bgmCompare_parallel(
394395
chain_rngs[c] = SafeRNG(seed + c);
395396
}
396397

398+
UpdateMethod update_method_enum = update_method_from_string(update_method);
399+
397400
// only used to determine the total no. warmup iterations, a bit hacky
398-
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method != "adaptive-metropolis"));
401+
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method_enum != adaptive_metropolis));
399402
int total_warmup = warmup_schedule_temp.total_warmup;
400403
ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type);
401404

@@ -408,7 +411,7 @@ Rcpp::List run_bgmCompare_parallel(
408411
baseline_category, difference_selection, main_effect_indices,
409412
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
410413
projection, group_membership, group_indices, interaction_index_matrix,
411-
inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs,
414+
inclusion_probability, chain_rngs, update_method_enum, hmc_num_leapfrogs,
412415
pm, results
413416
);
414417

src/bgmCompare_sampler.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,7 @@ void gibbs_update_step_bgmcompare (
13891389
SafeRNG& rng,
13901390
arma::mat& inclusion_probability,
13911391
int hmc_nuts_leapfrogs,
1392-
const std::string& update_method,
1392+
const UpdateMethod update_method,
13931393
arma::mat& proposal_sd_main,
13941394
arma::mat& proposal_sd_pair,
13951395
const arma::imat& index
@@ -1416,7 +1416,7 @@ void gibbs_update_step_bgmcompare (
14161416
}
14171417

14181418
// Step 2: Update parameters
1419-
if(update_method == "adaptive-metropolis") {
1419+
if(update_method == adaptive_metropolis) {
14201420
update_main_effects_metropolis_bgmcompare (
14211421
main_effects, pairwise_effects, main_effect_indices,
14221422
pairwise_effect_indices, inclusion_indicator, projection,
@@ -1434,7 +1434,7 @@ void gibbs_update_step_bgmcompare (
14341434
pairwise_scale, difference_scale, iteration, rwm_adapt_pair, rng,
14351435
proposal_sd_pair
14361436
);
1437-
} else if (update_method == "hamiltonian-mc") {
1437+
} else if (update_method == hamiltonian_mc) {
14381438
update_hmc_bgmcompare(
14391439
main_effects, pairwise_effects, main_effect_indices,
14401440
pairwise_effect_indices, inclusion_indicator, projection, num_categories,
@@ -1444,7 +1444,7 @@ void gibbs_update_step_bgmcompare (
14441444
main_beta, hmc_nuts_leapfrogs, iteration, hmc_adapt, learn_mass_matrix,
14451445
schedule.selection_enabled(iteration), rng
14461446
);
1447-
} else if (update_method == "nuts") {
1447+
} else if (update_method == nuts) {
14481448
SamplerResult result = update_nuts_bgmcompare(
14491449
main_effects, pairwise_effects, main_effect_indices,
14501450
pairwise_effect_indices, inclusion_indicator, projection, num_categories,
@@ -1577,7 +1577,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare(
15771577
const arma::imat& interaction_index_matrix,
15781578
arma::mat inclusion_probability,
15791579
SafeRNG& rng,
1580-
const std::string& update_method,
1580+
const UpdateMethod update_method,
15811581
const int hmc_num_leapfrogs,
15821582
ProgressManager& pm
15831583
) {
@@ -1618,7 +1618,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare(
16181618

16191619
// --- Optional HMC/NUTS warmup stage
16201620
double initial_step_size = 1.0;
1621-
if (update_method == "hamiltonian-mc" || update_method == "nuts") {
1621+
if (update_method == hamiltonian_mc || update_method == nuts) {
16221622
initial_step_size = find_initial_stepsize_bgmcompare(
16231623
main_effects, pairwise_effects, main_effect_indices,
16241624
pairwise_effect_indices, inclusion_indicator, projection, num_categories,

src/bgmCompare_sampler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <RcppArmadillo.h>
4+
#include "common_helpers.h"
45
#include <string>
56

67
struct SamplerOutput;
@@ -40,7 +41,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare(
4041
const arma::imat& interaction_index_matrix,
4142
arma::mat inclusion_probability,
4243
SafeRNG& rng,
43-
const std::string& update_method,
44+
const UpdateMethod update_method,
4445
const int hmc_num_leapfrogs,
4546
ProgressManager& pm
4647
);

src/bgm_parallel.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <string>
99
#include "progress_manager.h"
1010
#include "mcmc_adaptation.h"
11+
#include "common_helpers.h"
1112

1213
using namespace Rcpp;
1314
using namespace RcppParallel;
@@ -61,7 +62,7 @@ struct GibbsChainRunner : public Worker {
6162
const arma::imat& observations;
6263
const arma::ivec& num_categories;
6364
double pairwise_scale;
64-
const std::string& edge_prior;
65+
const EdgePrior edge_prior;
6566
const arma::mat& inclusion_probability;
6667
double beta_bernoulli_alpha;
6768
double beta_bernoulli_beta;
@@ -79,7 +80,7 @@ struct GibbsChainRunner : public Worker {
7980
const arma::uvec& is_ordinal_variable;
8081
const arma::ivec& baseline_category;
8182
bool edge_selection;
82-
const std::string& update_method;
83+
const UpdateMethod update_method;
8384
const arma::imat& pairwise_effect_indices;
8485
double target_accept;
8586
const arma::imat& pairwise_stats;
@@ -98,7 +99,7 @@ struct GibbsChainRunner : public Worker {
9899
const arma::imat& observations,
99100
const arma::ivec& num_categories,
100101
double pairwise_scale,
101-
const std::string& edge_prior,
102+
const EdgePrior edge_prior,
102103
const arma::mat& inclusion_probability,
103104
double beta_bernoulli_alpha,
104105
double beta_bernoulli_beta,
@@ -116,7 +117,7 @@ struct GibbsChainRunner : public Worker {
116117
const arma::uvec& is_ordinal_variable,
117118
const arma::ivec& baseline_category,
118119
bool edge_selection,
119-
const std::string& update_method,
120+
const UpdateMethod update_method,
120121
const arma::imat& pairwise_effect_indices,
121122
double target_accept,
122123
const arma::imat& pairwise_stats,
@@ -308,18 +309,20 @@ Rcpp::List run_bgm_parallel(
308309
chain_rngs[c] = SafeRNG(seed + c);
309310
}
310311

312+
UpdateMethod update_method_enum = update_method_from_string(update_method);
313+
EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior);
311314
// only used to determine the total no. warmup iterations, a bit hacky
312-
WarmupSchedule warmup_schedule_temp(warmup, edge_selection, (update_method != "adaptive-metropolis"));
315+
WarmupSchedule warmup_schedule_temp(warmup, edge_selection, (update_method_enum != adaptive_metropolis));
313316
int total_warmup = warmup_schedule_temp.total_warmup;
314317
ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type);
315318

316319
GibbsChainRunner worker(
317-
observations, num_categories, pairwise_scale, edge_prior,
320+
observations, num_categories, pairwise_scale, edge_prior_enum,
318321
inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta,
319322
dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup,
320323
counts_per_category, blume_capel_stats, main_alpha, main_beta,
321324
na_impute, missing_index, is_ordinal_variable, baseline_category,
322-
edge_selection, update_method, pairwise_effect_indices, target_accept,
325+
edge_selection, update_method_enum, pairwise_effect_indices, target_accept,
323326
pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix,
324327
chain_rngs, pm, results
325328
);

src/bgm_sampler.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ void gibbs_update_step_bgm (
10151015
const arma::uvec& is_ordinal_variable,
10161016
const arma::ivec& baseline_category,
10171017
const int iteration,
1018-
const std::string& update_method,
1018+
const UpdateMethod update_method,
10191019
const arma::imat& pairwise_effect_indices,
10201020
arma::imat& pairwise_stats,
10211021
const int hmc_num_leapfrogs,
@@ -1051,7 +1051,7 @@ void gibbs_update_step_bgm (
10511051
}
10521052

10531053
// Step 2a: Update interaction weights for active edges
1054-
if (update_method == "adaptive-metropolis") {
1054+
if (update_method == adaptive_metropolis) {
10551055
update_pairwise_effects_metropolis_bgm (
10561056
pairwise_effects, main_effects, inclusion_indicator, observations,
10571057
num_categories, proposal_sd_pairwise, adapt_pairwise, pairwise_scale,
@@ -1061,7 +1061,7 @@ void gibbs_update_step_bgm (
10611061
}
10621062

10631063
// Step 2b: Update main effect (main_effect) parameters
1064-
if (update_method == "adaptive-metropolis") {
1064+
if (update_method == adaptive_metropolis) {
10651065
update_main_effects_metropolis_bgm (
10661066
main_effects, observations, num_categories, counts_per_category,
10671067
blume_capel_stats, baseline_category, is_ordinal_variable,
@@ -1072,7 +1072,7 @@ void gibbs_update_step_bgm (
10721072
}
10731073

10741074
// Step 2: Update joint parameters if applicable
1075-
if (update_method == "hamiltonian-mc") {
1075+
if (update_method == hamiltonian_mc) {
10761076
update_hmc_bgm(
10771077
main_effects, pairwise_effects, inclusion_indicator, observations,
10781078
num_categories, counts_per_category, blume_capel_stats,
@@ -1081,7 +1081,7 @@ void gibbs_update_step_bgm (
10811081
iteration, adapt, learn_mass_matrix, schedule.selection_enabled(iteration),
10821082
rng
10831083
);
1084-
} else if (update_method == "nuts") {
1084+
} else if (update_method == nuts) {
10851085
SamplerResult result = update_nuts_bgm(
10861086
main_effects, pairwise_effects, inclusion_indicator,
10871087
observations, num_categories, counts_per_category, blume_capel_stats,
@@ -1171,7 +1171,7 @@ Rcpp::List run_gibbs_sampler_bgm(
11711171
arma::imat observations,
11721172
const arma::ivec& num_categories,
11731173
const double pairwise_scale,
1174-
const std::string& edge_prior,
1174+
const EdgePrior edge_prior,
11751175
arma::mat inclusion_probability,
11761176
const double beta_bernoulli_alpha,
11771177
const double beta_bernoulli_beta,
@@ -1189,7 +1189,7 @@ Rcpp::List run_gibbs_sampler_bgm(
11891189
const arma::uvec& is_ordinal_variable,
11901190
const arma::ivec& baseline_category,
11911191
bool edge_selection,
1192-
const std::string& update_method,
1192+
const UpdateMethod update_method,
11931193
const arma::imat pairwise_effect_indices,
11941194
const double target_accept,
11951195
arma::imat pairwise_stats,
@@ -1223,7 +1223,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12231223
if (edge_selection) {
12241224
indicator_samples.set_size(iter, num_pairwise);
12251225
}
1226-
if (edge_selection && edge_prior == "Stochastic-Block") {
1226+
if (edge_selection && edge_prior == Stochastic_Block) {
12271227
allocation_samples.set_size(iter, num_variables);
12281228
}
12291229

@@ -1245,7 +1245,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12451245
arma::vec log_Vn(1);
12461246

12471247
// --- Initialize SBM prior if applicable
1248-
if (edge_prior == "Stochastic-Block") {
1248+
if (edge_prior == Stochastic_Block) {
12491249
cluster_allocations[0] = 0;
12501250
cluster_allocations[1] = 1;
12511251
for (int i = 2; i < num_variables; i++) {
@@ -1273,7 +1273,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12731273

12741274
// --- Optional HMC/NUTS warmup stage
12751275
double initial_step_size_joint = 1.0;
1276-
if (update_method == "hamiltonian-mc" || update_method == "nuts") {
1276+
if (update_method == hamiltonian_mc || update_method == nuts) {
12771277
initial_step_size_joint = find_initial_stepsize_bgm(
12781278
main_effects, pairwise_effects, inclusion_indicator, observations,
12791279
num_categories, counts_per_category, blume_capel_stats,
@@ -1283,7 +1283,7 @@ Rcpp::List run_gibbs_sampler_bgm(
12831283
}
12841284

12851285
// --- Warmup scheduling + adaptation controller
1286-
WarmupSchedule warmup_schedule(warmup, edge_selection, (update_method != "adaptive-metropolis"));
1286+
WarmupSchedule warmup_schedule(warmup, edge_selection, (update_method != adaptive_metropolis));
12871287
HMCAdaptationController adapt_joint(
12881288
num_main + num_pairwise, initial_step_size_joint, target_accept,
12891289
warmup_schedule, learn_mass_matrix
@@ -1339,7 +1339,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13391339

13401340
// --- Update edge probabilities under the prior (if edge selection is active)
13411341
if (warmup_schedule.selection_enabled(iteration)) {
1342-
if (edge_prior == "Beta-Bernoulli") {
1342+
if (edge_prior == Beta_Bernoulli) {
13431343
int num_edges_included = 0;
13441344
for (int i = 0; i < num_variables - 1; i++)
13451345
for (int j = i + 1; j < num_variables; j++)
@@ -1354,7 +1354,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13541354
for (int j = i + 1; j < num_variables; j++)
13551355
inclusion_probability(i, j) = inclusion_probability(j, i) = prob;
13561356

1357-
} else if (edge_prior == "Stochastic-Block") {
1357+
} else if (edge_prior == Stochastic_Block) {
13581358
cluster_allocations = block_allocations_mfm_sbm(
13591359
cluster_allocations, num_variables, log_Vn, cluster_prob,
13601360
arma::conv_to<arma::umat>::from(inclusion_indicator), dirichlet_alpha,
@@ -1396,7 +1396,7 @@ Rcpp::List run_gibbs_sampler_bgm(
13961396
}
13971397
}
13981398

1399-
if (edge_selection && edge_prior == "Stochastic-Block") {
1399+
if (edge_selection && edge_prior == Stochastic_Block) {
14001400
for (int v = 0; v < num_variables; v++) {
14011401
allocation_samples(sample_index, v) = cluster_allocations[v] + 1;
14021402
}
@@ -1408,7 +1408,7 @@ Rcpp::List run_gibbs_sampler_bgm(
14081408
out["main_samples"] = main_effect_samples;
14091409
out["pairwise_samples"] = pairwise_effect_samples;
14101410

1411-
if (update_method == "nuts") {
1411+
if (update_method == nuts) {
14121412
out["treedepth__"] = treedepth_samples;
14131413
out["divergent__"] = divergent_samples;
14141414
out["energy__"] = energy_samples;
@@ -1418,7 +1418,7 @@ Rcpp::List run_gibbs_sampler_bgm(
14181418
out["indicator_samples"] = indicator_samples;
14191419
}
14201420

1421-
if (edge_selection && edge_prior == "Stochastic-Block") {
1421+
if (edge_selection && edge_prior == Stochastic_Block) {
14221422
out["allocations"] = allocation_samples;
14231423
}
14241424

src/bgm_sampler.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <RcppArmadillo.h>
3+
#include "common_helpers.h"
34
// forward declaration
45
struct SafeRNG;
56
class ProgressManager;
@@ -9,7 +10,7 @@ Rcpp::List run_gibbs_sampler_bgm(
910
arma::imat observations,
1011
const arma::ivec& num_categories,
1112
const double pairwise_scale,
12-
const std::string& edge_prior,
13+
const EdgePrior edge_prior,
1314
arma::mat inclusion_probability,
1415
const double beta_bernoulli_alpha,
1516
const double beta_bernoulli_beta,
@@ -27,7 +28,7 @@ Rcpp::List run_gibbs_sampler_bgm(
2728
const arma::uvec& is_ordinal_variable,
2829
const arma::ivec& baseline_category,
2930
bool edge_selection,
30-
const std::string& update_method,
31+
const UpdateMethod update_method,
3132
const arma::imat pairwise_effect_indices,
3233
const double target_accept,
3334
arma::imat pairwise_stats,

src/common_helpers.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,30 @@ inline int count_num_main_effects(const arma::ivec& num_categories,
2525
}
2626
return n_params;
2727
}
28+
29+
enum UpdateMethod { adaptive_metropolis, hamiltonian_mc, nuts };
30+
31+
inline UpdateMethod update_method_from_string(const std::string& update_method) {
32+
if (update_method == "adaptive-metropolis")
33+
return adaptive_metropolis;
34+
35+
if (update_method == "hamiltonian-mc")
36+
return hamiltonian_mc;
37+
38+
if (update_method == "nuts")
39+
return nuts;
40+
41+
throw std::invalid_argument("Invalid update_method: " + update_method);
42+
}
43+
44+
enum EdgePrior { Stochastic_Block, Beta_Bernoulli };
45+
46+
inline EdgePrior edge_prior_from_string(const std::string& edge_prior) {
47+
if (edge_prior == "stochastic-block")
48+
return Stochastic_Block;
49+
50+
if (edge_prior == "Beta-Bernoulli")
51+
return Beta_Bernoulli;
52+
53+
throw std::invalid_argument("Invalid edge_prior: " + edge_prior);
54+
}

0 commit comments

Comments
 (0)