diff --git a/R/output_utils.R b/R/output_utils.R index 07007ead..f522fa63 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -99,20 +99,19 @@ prepare_output_bgm = function ( colnames(results$inclusion_indicator_samples) = edge_names } + results$arguments = arguments + class(results) = "bgms" + # SBM postprocessing if (edge_selection && edge_prior == "Stochastic-Block" && "allocations" %in% names(out)) { - results$allocations <- out$allocations + results$arguments$allocations = out$allocations # Requires that summarySBM() is available in namespace - sbm_summary <- summarySBM(list( - indicator = results$inclusion_indicator, - allocations = results$allocations - ), internal_call = TRUE) - results$components <- sbm_summary$components - results$allocations <- sbm_summary$allocations + sbm_summary = summarySBM(results, internal_call = TRUE) + results$components = sbm_summary$components + results$allocations = sbm_summary$allocations } - results$arguments = arguments - class(results) = "bgms" + return(results) } diff --git a/R/posterior_utils.R b/R/posterior_utils.R index 71336b78..d51fc0fb 100644 --- a/R/posterior_utils.R +++ b/R/posterior_utils.R @@ -121,10 +121,9 @@ summarySBM = function( if(arguments$save == FALSE && internal_call == FALSE) stop('The bgm function must be run with save = TRUE.') - cluster_allocations = bgm_object$allocations + cluster_allocations = arguments$allocations dirichlet_alpha = arguments$dirichlet_alpha lambda = arguments$lambda - # Pre-compute log_Vn for computing the cluster probabilities num_variables = ncol(cluster_allocations) log_Vn = compute_Vn_mfm_sbm( diff --git a/src/gibbs_functions_edge_prior.cpp b/src/gibbs_functions_edge_prior.cpp index 3650acfd..4155f7de 100644 --- a/src/gibbs_functions_edge_prior.cpp +++ b/src/gibbs_functions_edge_prior.cpp @@ -62,30 +62,16 @@ arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double r; double tmp; - for(arma::uword t = 1; t <= t_max; t++) { + for(arma::uword t = 0; t < t_max; t++) { r = -INFINITY; - for(arma::uword k = t; k < 500; k++) { - tmp = 0.0; - for(arma::uword tt = 1 - t; tt < 1; tt ++) { - tmp += std::log(dirichlet_alpha * k + tt); - } - for(arma::uword n = 0; n < no_variables; n++) { - tmp -= std::log(dirichlet_alpha * k + n); - } - tmp -= std::lgamma(k + 1); - - // Add the poisson term - double log_norm_factor = log(1.0 - exp(R::dpois(0, lambda, true))); - tmp += R::dpois(k-1, lambda, true) - log_norm_factor; - - // Compute the maximum between r and tmp - if (tmp > r) { - r = std::log(std::exp(r - tmp) + 1) + tmp; - } else { - r = std::log(1 + std::exp(tmp - r)) + r; - } + for(arma::uword k = t; k <= 500; k++){ + arma::vec b_linspace_1 = arma::linspace(k-t+1,k+1,t+1); + arma::vec b_linspace_2 = arma::linspace((k+1)*dirichlet_alpha,(k+1)*dirichlet_alpha+no_variables-1, no_variables); + double b = arma::accu(arma::log(b_linspace_1))-arma::accu(arma::log(b_linspace_2)) + R::dpois((k+1)-1, lambda, true); + double m = std::max(b,r); + r = std::log(std::exp(r-m) + std::exp(b-m)) + m; } - log_Vn(t-1) = r - std::log(std::exp(1) - 1); + log_Vn(t) = r; } return log_Vn; } @@ -129,24 +115,20 @@ double log_marginal_mfm_sbm(arma::uvec cluster_assign, double beta_bernoulli_alpha, double beta_bernoulli_beta) { - arma::uword no_clusters_excl_node = arma::max(cluster_assign); - //*std::max_element(cluster_assign.begin(), cluster_assign.end()); + arma::uvec indices = arma::regspace(0, no_variables-1); + arma::uvec select_variables = indices(arma::find(indices != node)); + arma::uvec cluster_assign_wo_node = cluster_assign(select_variables); + arma::uvec indicator_node = indicator.col(node); + arma::vec gamma_node = arma::conv_to::from(indicator_node(select_variables)); + arma::uvec table_cluster = table_cpp(cluster_assign_wo_node); double output = 0; - double sumG; - double sumN; - - output -= no_clusters_excl_node * R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta); - - for(arma::uword c = 0; c < no_clusters_excl_node; c++) { - sumG = 0; - sumN = 0; - for(arma::uword i = 0; i < no_variables; i++) { - if(cluster_assign(i) == c) { - sumG += static_cast(indicator(node, i)); - sumN += 1.0; - } + for(arma::uword i = 0; i < table_cluster.n_elem; i++){ + if(table_cluster(i) > 0){ // if the cluster is empty (it is the previous cluster of 'node' where 'node' was the only member - a singleton, thus skip) + arma::uvec which_variables_cluster_i = arma::find(cluster_assign_wo_node == i); + double sumG = arma::accu(gamma_node(which_variables_cluster_i)); + double sumN = static_cast(which_variables_cluster_i.n_elem); + output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta) - R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta); } - output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta); } return output; }