Skip to content

Commit 0f03be4

Browse files
authored
Merge pull request #47 from jupepis/adaMala
SBM prior - correction Armadillo routines
2 parents 2b3115d + 270870d commit 0f03be4

File tree

3 files changed

+29
-49
lines changed

3 files changed

+29
-49
lines changed

R/output_utils.R

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,19 @@ prepare_output_bgm = function (
9999
colnames(results$inclusion_indicator_samples) = edge_names
100100
}
101101

102+
results$arguments = arguments
103+
class(results) = "bgms"
104+
102105
# SBM postprocessing
103106
if (edge_selection && edge_prior == "Stochastic-Block" && "allocations" %in% names(out)) {
104-
results$allocations <- out$allocations
107+
results$arguments$allocations = out$allocations
105108
# Requires that summarySBM() is available in namespace
106-
sbm_summary <- summarySBM(list(
107-
indicator = results$inclusion_indicator,
108-
allocations = results$allocations
109-
), internal_call = TRUE)
110-
results$components <- sbm_summary$components
111-
results$allocations <- sbm_summary$allocations
109+
sbm_summary = summarySBM(results, internal_call = TRUE)
110+
results$components = sbm_summary$components
111+
results$allocations = sbm_summary$allocations
112112
}
113113

114-
results$arguments = arguments
115-
class(results) = "bgms"
114+
116115
return(results)
117116
}
118117

R/posterior_utils.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,9 @@ summarySBM = function(
121121
if(arguments$save == FALSE && internal_call == FALSE)
122122
stop('The bgm function must be run with save = TRUE.')
123123

124-
cluster_allocations = bgm_object$allocations
124+
cluster_allocations = arguments$allocations
125125
dirichlet_alpha = arguments$dirichlet_alpha
126126
lambda = arguments$lambda
127-
128127
# Pre-compute log_Vn for computing the cluster probabilities
129128
num_variables = ncol(cluster_allocations)
130129
log_Vn = compute_Vn_mfm_sbm(

src/gibbs_functions_edge_prior.cpp

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,16 @@ arma::vec compute_Vn_mfm_sbm(arma::uword no_variables,
6262
double r;
6363
double tmp;
6464

65-
for(arma::uword t = 1; t <= t_max; t++) {
65+
for(arma::uword t = 0; t < t_max; t++) {
6666
r = -INFINITY;
67-
for(arma::uword k = t; k < 500; k++) {
68-
tmp = 0.0;
69-
for(arma::uword tt = 1 - t; tt < 1; tt ++) {
70-
tmp += std::log(dirichlet_alpha * k + tt);
71-
}
72-
for(arma::uword n = 0; n < no_variables; n++) {
73-
tmp -= std::log(dirichlet_alpha * k + n);
74-
}
75-
tmp -= std::lgamma(k + 1);
76-
77-
// Add the poisson term
78-
double log_norm_factor = log(1.0 - exp(R::dpois(0, lambda, true)));
79-
tmp += R::dpois(k-1, lambda, true) - log_norm_factor;
80-
81-
// Compute the maximum between r and tmp
82-
if (tmp > r) {
83-
r = std::log(std::exp(r - tmp) + 1) + tmp;
84-
} else {
85-
r = std::log(1 + std::exp(tmp - r)) + r;
86-
}
67+
for(arma::uword k = t; k <= 500; k++){
68+
arma::vec b_linspace_1 = arma::linspace(k-t+1,k+1,t+1);
69+
arma::vec b_linspace_2 = arma::linspace((k+1)*dirichlet_alpha,(k+1)*dirichlet_alpha+no_variables-1, no_variables);
70+
double b = arma::accu(arma::log(b_linspace_1))-arma::accu(arma::log(b_linspace_2)) + R::dpois((k+1)-1, lambda, true);
71+
double m = std::max(b,r);
72+
r = std::log(std::exp(r-m) + std::exp(b-m)) + m;
8773
}
88-
log_Vn(t-1) = r - std::log(std::exp(1) - 1);
74+
log_Vn(t) = r;
8975
}
9076
return log_Vn;
9177
}
@@ -129,24 +115,20 @@ double log_marginal_mfm_sbm(arma::uvec cluster_assign,
129115
double beta_bernoulli_alpha,
130116
double beta_bernoulli_beta) {
131117

132-
arma::uword no_clusters_excl_node = arma::max(cluster_assign);
133-
//*std::max_element(cluster_assign.begin(), cluster_assign.end());
118+
arma::uvec indices = arma::regspace<arma::uvec>(0, no_variables-1);
119+
arma::uvec select_variables = indices(arma::find(indices != node));
120+
arma::uvec cluster_assign_wo_node = cluster_assign(select_variables);
121+
arma::uvec indicator_node = indicator.col(node);
122+
arma::vec gamma_node = arma::conv_to<arma::vec>::from(indicator_node(select_variables));
123+
arma::uvec table_cluster = table_cpp(cluster_assign_wo_node);
134124
double output = 0;
135-
double sumG;
136-
double sumN;
137-
138-
output -= no_clusters_excl_node * R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta);
139-
140-
for(arma::uword c = 0; c < no_clusters_excl_node; c++) {
141-
sumG = 0;
142-
sumN = 0;
143-
for(arma::uword i = 0; i < no_variables; i++) {
144-
if(cluster_assign(i) == c) {
145-
sumG += static_cast<double>(indicator(node, i));
146-
sumN += 1.0;
147-
}
125+
for(arma::uword i = 0; i < table_cluster.n_elem; i++){
126+
if(table_cluster(i) > 0){ // if the cluster is empty (it is the previous cluster of 'node' where 'node' was the only member - a singleton, thus skip)
127+
arma::uvec which_variables_cluster_i = arma::find(cluster_assign_wo_node == i);
128+
double sumG = arma::accu(gamma_node(which_variables_cluster_i));
129+
double sumN = static_cast<double>(which_variables_cluster_i.n_elem);
130+
output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta) - R::lbeta(beta_bernoulli_alpha, beta_bernoulli_beta);
148131
}
149-
output += R::lbeta(sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta);
150132
}
151133
return output;
152134
}

0 commit comments

Comments
 (0)