Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions R/output_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,19 @@ prepare_output_bgm = function (
colnames(results$inclusion_indicator_samples) = edge_names
}

results$arguments = arguments
class(results) = "bgms"

# SBM postprocessing
if (edge_selection && edge_prior == "Stochastic-Block" && "allocations" %in% names(out)) {
results$allocations <- out$allocations
results$arguments$allocations = out$allocations
# Requires that summarySBM() is available in namespace
sbm_summary <- summarySBM(list(
indicator = results$inclusion_indicator,
allocations = results$allocations
), internal_call = TRUE)
results$components <- sbm_summary$components
results$allocations <- sbm_summary$allocations
sbm_summary = summarySBM(results, internal_call = TRUE)
results$components = sbm_summary$components
results$allocations = sbm_summary$allocations
}

results$arguments = arguments
class(results) = "bgms"

return(results)
}

Expand Down
3 changes: 1 addition & 2 deletions R/posterior_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,9 @@ summarySBM = function(
if(arguments$save == FALSE && internal_call == FALSE)
stop('The bgm function must be run with save = TRUE.')

cluster_allocations = bgm_object$allocations
cluster_allocations = arguments$allocations
dirichlet_alpha = arguments$dirichlet_alpha
lambda = arguments$lambda

# Pre-compute log_Vn for computing the cluster probabilities
num_variables = ncol(cluster_allocations)
log_Vn = compute_Vn_mfm_sbm(
Expand Down
58 changes: 20 additions & 38 deletions src/gibbs_functions_edge_prior.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,16 @@ arma::vec compute_Vn_mfm_sbm(arma::uword no_variables,
double r;
double tmp;

for(arma::uword t = 1; t <= t_max; t++) {
for(arma::uword t = 0; t < t_max; t++) {
r = -INFINITY;
for(arma::uword k = t; k < 500; k++) {
tmp = 0.0;
for(arma::uword tt = 1 - t; tt < 1; tt ++) {
tmp += std::log(dirichlet_alpha * k + tt);
}
for(arma::uword n = 0; n < no_variables; n++) {
tmp -= std::log(dirichlet_alpha * k + n);
}
tmp -= std::lgamma(k + 1);

// Add the poisson term
double log_norm_factor = log(1.0 - exp(R::dpois(0, lambda, true)));
tmp += R::dpois(k-1, lambda, true) - log_norm_factor;

// Compute the maximum between r and tmp
if (tmp > r) {
r = std::log(std::exp(r - tmp) + 1) + tmp;
} else {
r = std::log(1 + std::exp(tmp - r)) + r;
}
for(arma::uword k = t; k <= 500; k++){
arma::vec b_linspace_1 = arma::linspace(k-t+1,k+1,t+1);
arma::vec b_linspace_2 = arma::linspace((k+1)*dirichlet_alpha,(k+1)*dirichlet_alpha+no_variables-1, no_variables);
double b = arma::accu(arma::log(b_linspace_1))-arma::accu(arma::log(b_linspace_2)) + R::dpois((k+1)-1, lambda, true);
double m = std::max(b,r);
r = std::log(std::exp(r-m) + std::exp(b-m)) + m;
}
log_Vn(t-1) = r - std::log(std::exp(1) - 1);
log_Vn(t) = r;
}
return log_Vn;
}
Expand Down Expand Up @@ -129,24 +115,20 @@ double log_marginal_mfm_sbm(arma::uvec cluster_assign,
double beta_bernoulli_alpha,
double beta_bernoulli_beta) {

arma::uword no_clusters_excl_node = arma::max(cluster_assign);
//*std::max_element(cluster_assign.begin(), cluster_assign.end());
arma::uvec indices = arma::regspace<arma::uvec>(0, no_variables-1);
arma::uvec select_variables = indices(arma::find(indices != node));
arma::uvec cluster_assign_wo_node = cluster_assign(select_variables);
arma::uvec indicator_node = indicator.col(node);
arma::vec gamma_node = arma::conv_to<arma::vec>::from(indicator_node(select_variables));
arma::uvec table_cluster = table_cpp(cluster_assign_wo_node);
double output = 0;
double sumG;
double sumN;

output -= no_clusters_excl_node * R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta);

for(arma::uword c = 0; c < no_clusters_excl_node; c++) {
sumG = 0;
sumN = 0;
for(arma::uword i = 0; i < no_variables; i++) {
if(cluster_assign(i) == c) {
sumG += static_cast<double>(indicator(node, i));
sumN += 1.0;
}
for(arma::uword i = 0; i < table_cluster.n_elem; i++){
if(table_cluster(i) > 0){ // if the cluster is empty (it is the previous cluster of 'node' where 'node' was the only member - a singleton, thus skip)
arma::uvec which_variables_cluster_i = arma::find(cluster_assign_wo_node == i);
double sumG = arma::accu(gamma_node(which_variables_cluster_i));
double sumN = static_cast<double>(which_variables_cluster_i.n_elem);
output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta) - R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta);
}
output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta);
}
return output;
}
Expand Down