@@ -62,30 +62,16 @@ arma::vec compute_Vn_mfm_sbm(arma::uword no_variables,
6262 double r;
6363 double tmp;
6464
65- for (arma::uword t = 1 ; t <= t_max; t++) {
65+ for (arma::uword t = 0 ; t < t_max; t++) {
6666 r = -INFINITY;
67- for (arma::uword k = t; k < 500 ; k++) {
68- tmp = 0.0 ;
69- for (arma::uword tt = 1 - t; tt < 1 ; tt ++) {
70- tmp += std::log (dirichlet_alpha * k + tt);
71- }
72- for (arma::uword n = 0 ; n < no_variables; n++) {
73- tmp -= std::log (dirichlet_alpha * k + n);
74- }
75- tmp -= std::lgamma (k + 1 );
76-
77- // Add the poisson term
78- double log_norm_factor = log (1.0 - exp (R::dpois (0 , lambda, true )));
79- tmp += R::dpois (k-1 , lambda, true ) - log_norm_factor;
80-
81- // Compute the maximum between r and tmp
82- if (tmp > r) {
83- r = std::log (std::exp (r - tmp) + 1 ) + tmp;
84- } else {
85- r = std::log (1 + std::exp (tmp - r)) + r;
86- }
67+ for (arma::uword k = t; k <= 500 ; k++){
68+ arma::vec b_linspace_1 = arma::linspace (k-t+1 ,k+1 ,t+1 );
69+ arma::vec b_linspace_2 = arma::linspace ((k+1 )*dirichlet_alpha,(k+1 )*dirichlet_alpha+no_variables-1 , no_variables);
70+ double b = arma::accu (arma::log (b_linspace_1))-arma::accu (arma::log (b_linspace_2)) + R::dpois ((k+1 )-1 , lambda, true );
71+ double m = std::max (b,r);
72+ r = std::log (std::exp (r-m) + std::exp (b-m)) + m;
8773 }
88- log_Vn (t- 1 ) = r - std::log ( std::exp ( 1 ) - 1 ) ;
74+ log_Vn (t) = r;
8975 }
9076 return log_Vn;
9177}
@@ -129,24 +115,20 @@ double log_marginal_mfm_sbm(arma::uvec cluster_assign,
129115 double beta_bernoulli_alpha,
130116 double beta_bernoulli_beta) {
131117
132- arma::uword no_clusters_excl_node = arma::max (cluster_assign);
133- // *std::max_element(cluster_assign.begin(), cluster_assign.end());
118+ arma::uvec indices = arma::regspace<arma::uvec>(0 , no_variables-1 );
119+ arma::uvec select_variables = indices (arma::find (indices != node));
120+ arma::uvec cluster_assign_wo_node = cluster_assign (select_variables);
121+ arma::uvec indicator_node = indicator.col (node);
122+ arma::vec gamma_node = arma::conv_to<arma::vec>::from (indicator_node (select_variables));
123+ arma::uvec table_cluster = table_cpp (cluster_assign_wo_node);
134124 double output = 0 ;
135- double sumG;
136- double sumN;
137-
138- output -= no_clusters_excl_node * R::lbeta (beta_bernoulli_alpha, beta_bernoulli_beta);
139-
140- for (arma::uword c = 0 ; c < no_clusters_excl_node; c++) {
141- sumG = 0 ;
142- sumN = 0 ;
143- for (arma::uword i = 0 ; i < no_variables; i++) {
144- if (cluster_assign (i) == c) {
145- sumG += static_cast <double >(indicator (node, i));
146- sumN += 1.0 ;
147- }
125+ for (arma::uword i = 0 ; i < table_cluster.n_elem ; i++){
126+ if (table_cluster (i) > 0 ){ // if the cluster is empty (it is the previous cluster of 'node' where 'node' was the only member - a singleton, thus skip)
127+ arma::uvec which_variables_cluster_i = arma::find (cluster_assign_wo_node == i);
128+ double sumG = arma::accu (gamma_node (which_variables_cluster_i));
129+ double sumN = static_cast <double >(which_variables_cluster_i.n_elem );
130+ output += R::lbeta (sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta) - R::lbeta (beta_bernoulli_alpha, beta_bernoulli_beta);
148131 }
149- output += R::lbeta (sumG + beta_bernoulli_alpha, sumN - sumG + beta_bernoulli_beta);
150132 }
151133 return output;
152134}
0 commit comments