Skip to content

Commit 12ef013

Browse files
authored
Merge pull request #66 from FMKerckhof/master
Potential fixes for #48 (assisted by copilot)
2 parents be87b58 + 76cad3c commit 12ef013

File tree

11 files changed

+211
-9
lines changed

11 files changed

+211
-9
lines changed

.Rbuildignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
^_pkgdown\.yml$
66
^docs$
77
^pkgdown$
8+
^.*\.Rproj$
9+
^\.Rproj\.user$

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
docs/
22
.Rhistory
33
docs
4+
.Rproj.user

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Type: Package
33
Title: Gaussian Mixture Models, K-Means, Mini-Batch-Kmeans, K-Medoids and Affinity Propagation Clustering
44
Version: 1.3.4
55
Date: 2025-09-14
6-
Authors@R: c( person(given = "Lampros", family = "Mouselimis", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-8024-1546")), person(given = "Conrad", family = "Sanderson", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Ryan", family = "Curtin", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Siddharth", family = "Agrawal", role = "cph", comment = "Author of the C code of the Mini-Batch-Kmeans algorithm (https://github.com/siddharth-agrawal/Mini-Batch-K-Means)"), person(given = "Brendan", family = "Frey", email = "[email protected]", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm (for commercial use please contact the author of the matlab code)"), person(given = "Delbert", family = "Dueck", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm"), person(given = "Vitalie", family = "Spinu", email = "[email protected]", role = "ctb", comment = "Github Contributor") )
6+
Authors@R: c( person(given = "Lampros", family = "Mouselimis", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-8024-1546")), person(given = "Conrad", family = "Sanderson", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Ryan", family = "Curtin", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Siddharth", family = "Agrawal", role = "cph", comment = "Author of the C code of the Mini-Batch-Kmeans algorithm (https://github.com/siddharth-agrawal/Mini-Batch-K-Means)"), person(given = "Brendan", family = "Frey", email = "[email protected]", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm (for commercial use please contact the author of the matlab code)"), person(given = "Delbert", family = "Dueck", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm"), person(given = "Vitalie", family = "Spinu", email = "[email protected]", role = "ctb", comment = "Github Contributor"),person(given = "Frederiek - Maarten", family = "Kerckhof", email = "[email protected]", role = "ctb", comment = "Github Contributor") )
77
BugReports: https://github.com/mlampros/ClusterR/issues
88
URL: https://github.com/mlampros/ClusterR, https://mlampros.github.io/ClusterR/
99
Description: Gaussian mixture models, k-means, mini-batch-kmeans, k-medoids and affinity propagation clustering with the option to plot, validate, predict (new data) and estimate the optimal number of clusters. The package takes advantage of 'RcppArmadillo' to speed up the computationally intensive parts of the functions. For more information, see (i) "Clustering in an Object-Oriented Environment" by Anja Struyf, Mia Hubert, Peter Rousseeuw (1997), Journal of Statistical Software, <doi:10.18637/jss.v001.i04>; (ii) "Web-scale k-means clustering" by D. Sculley (2010), ACM Digital Library, <doi:10.1145/1772690.1772862>; (iii) "Armadillo: a template-based C++ library for linear algebra" by Sanderson et al (2016), The Journal of Open Source Software, <doi:10.21105/joss.00026>; (iv) "Clustering by Passing Messages Between Data Points" by Brendan J. Frey and Delbert Dueck, Science 16 Feb 2007: Vol. 315, Issue 5814, pp. 972-976, <doi:10.1126/science.1136800>.
@@ -30,4 +30,4 @@ Suggests:
3030
knitr,
3131
rmarkdown
3232
VignetteBuilder: knitr
33-
RoxygenNote: 7.3.2
33+
RoxygenNote: 7.3.3

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ predict_MGausDPDF <- function(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8
4949
.Call(`_ClusterR_predict_MGausDPDF`, data, CENTROIDS, COVARIANCE, WEIGHTS, eps)
5050
}
5151

52+
predict_MGausDPDF_full <- function(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8) {
53+
.Call(`_ClusterR_predict_MGausDPDF_full`, data, CENTROIDS, COVARIANCE, WEIGHTS, eps)
54+
}
55+
5256
GMM_arma_AIC_BIC <- function(data, max_clusters, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor = 1e-10, criterion = "AIC", seed = 1L) {
5357
.Call(`_ClusterR_GMM_arma_AIC_BIC`, data, max_clusters, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, criterion, seed)
5458
}

R/clustering_functions.R

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,13 @@ GMM = function(data,
131131
#'
132132
#' @param data matrix or data frame
133133
#' @param CENTROIDS matrix or data frame containing the centroids (means), stored as row vectors
134-
#' @param COVARIANCE matrix or data frame containing the diagonal covariance matrices, stored as row vectors
134+
#' @param COVARIANCE matrix or data frame (for diagonal covariance) or 3D array (for full covariance matrices)
135135
#' @param WEIGHTS vector containing the weights
136136
#' @return a list consisting of the log-likelihoods, cluster probabilities and cluster labels.
137137
#' @author Lampros Mouselimis
138138
#' @details
139139
#' This function takes the centroids, covariance matrix and weights from a trained model and returns the log-likelihoods, cluster probabilities and cluster labels for new data.
140+
#' The function handles both diagonal covariance matrices (2D matrix) and full covariance matrices (3D array/cube).
140141
#' @export
141142
#' @examples
142143
#'
@@ -156,18 +157,34 @@ predict_GMM = function(data, CENTROIDS, COVARIANCE, WEIGHTS) {
156157
if (!inherits(data, 'matrix')) stop('data should be either a matrix or a data frame')
157158
if ('data.frame' %in% class(CENTROIDS)) CENTROIDS = as.matrix(CENTROIDS)
158159
if (!inherits(CENTROIDS, 'matrix')) stop('CENTROIDS should be either a matrix or a data frame')
159-
if ('data.frame' %in% class(COVARIANCE)) COVARIANCE = as.matrix(COVARIANCE)
160-
if (!inherits(COVARIANCE, 'matrix')) stop('COVARIANCE should be either a matrix or a data frame')
161-
if (ncol(data) != ncol(CENTROIDS) || ncol(data) != ncol(COVARIANCE) || length(WEIGHTS) != nrow(CENTROIDS) || length(WEIGHTS) != nrow(COVARIANCE))
162-
stop('the number of columns of the data, CENTROIDS and COVARIANCE should match and the number of rows of the CENTROIDS AND COVARIANCE should be equal to the length of the WEIGHTS vector')
163160
if (!inherits(WEIGHTS, 'numeric') || !is.vector(WEIGHTS))
164161
stop('WEIGHTS should be a numeric vector')
165162

166163
flag_non_finite = check_NaN_Inf(data)
167164

168165
if (!flag_non_finite) stop("the data includes NaN's or +/- Inf values")
169166

170-
res = predict_MGausDPDF(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8)
167+
# Check if COVARIANCE is a 3D array (full covariance) or 2D matrix (diagonal covariance)
168+
is_full_covariance = length(dim(COVARIANCE)) == 3
169+
170+
if (is_full_covariance) {
171+
# Full covariance matrices - COVARIANCE is a 3D array (cube)
172+
if (!is.array(COVARIANCE)) stop('COVARIANCE should be a 3D array for full covariance matrices')
173+
if (dim(COVARIANCE)[1] != ncol(data) || dim(COVARIANCE)[2] != ncol(data) || dim(COVARIANCE)[3] != length(WEIGHTS))
174+
stop('for full covariance: dim(COVARIANCE)[1] and dim(COVARIANCE)[2] should equal ncol(data), and dim(COVARIANCE)[3] should equal length(WEIGHTS)')
175+
if (ncol(data) != ncol(CENTROIDS) || length(WEIGHTS) != nrow(CENTROIDS))
176+
stop('the number of columns of the data and CENTROIDS should match and the number of rows of CENTROIDS should equal the length of the WEIGHTS vector')
177+
178+
res = predict_MGausDPDF_full(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8)
179+
} else {
180+
# Diagonal covariance matrices - COVARIANCE is a 2D matrix
181+
if ('data.frame' %in% class(COVARIANCE)) COVARIANCE = as.matrix(COVARIANCE)
182+
if (!inherits(COVARIANCE, 'matrix')) stop('COVARIANCE should be either a matrix or a data frame for diagonal covariance')
183+
if (ncol(data) != ncol(CENTROIDS) || ncol(data) != ncol(COVARIANCE) || length(WEIGHTS) != nrow(CENTROIDS) || length(WEIGHTS) != nrow(COVARIANCE))
184+
stop('the number of columns of the data, CENTROIDS and COVARIANCE should match and the number of rows of the CENTROIDS AND COVARIANCE should be equal to the length of the WEIGHTS vector')
185+
186+
res = predict_MGausDPDF(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8)
187+
}
171188

172189
# I've added 1 to the output cluster labels to account for the difference in indexing between R and C++
173190
list(log_likelihood = res$Log_likelihood_raw,

inst/include/ClusterRHeader.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,80 @@ namespace clustR {
17491749

17501750

17511751

1752+
// predict function for full covariance matrices (3D cube)
1753+
// This handles the case when GMM is fitted with full_covariance_matrices=TRUE
1754+
//
1755+
1756+
Rcpp::List predict_MGausDPDF_full(arma::mat data, arma::mat CENTROIDS, arma::cube COVARIANCE, arma::vec WEIGHTS, double eps = 1.0e-8) {
1757+
1758+
arma::mat gaus_mat(data.n_rows, WEIGHTS.n_elem, arma::fill::zeros);
1759+
1760+
arma::mat gaus_mat_log_lik(data.n_rows, WEIGHTS.n_elem, arma::fill::zeros);
1761+
1762+
for (unsigned int i = 0; i < WEIGHTS.n_elem; i++) {
1763+
1764+
arma::vec gaus_vec(data.n_rows, arma::fill::zeros);
1765+
1766+
arma::vec gaus_vec_log(data.n_rows, arma::fill::zeros);
1767+
1768+
// Extract the full covariance matrix for this Gaussian component
1769+
arma::mat tmp_cov_mt = COVARIANCE.slice(i);
1770+
1771+
double tmp_determinant = arma::det(tmp_cov_mt);
1772+
1773+
if (tmp_determinant == 0.0) {
1774+
1775+
Rcpp::stop("the determinant is zero or approximately zero. The data might include highly correlated variables or variables with low variance");
1776+
}
1777+
1778+
// Compute inverse of covariance matrix
1779+
arma::mat inv_cov_mt = arma::inv(tmp_cov_mt);
1780+
1781+
for (unsigned int j = 0; j < data.n_rows; j++) {
1782+
1783+
double n = data.n_cols;
1784+
1785+
arma::vec tmp_vec = (arma::conv_to< arma::vec >::from(data.row(j)) - arma::conv_to< arma::vec >::from(CENTROIDS.row(i)));
1786+
1787+
double tmp_val = 1.0 / std::sqrt(std::pow(2.0 * arma::datum::pi, n) * tmp_determinant);
1788+
1789+
double inner_likelih = 0.5 * (arma::as_scalar(tmp_vec.t() * inv_cov_mt * tmp_vec));
1790+
1791+
gaus_vec_log(j) = -(n / 2.0) * std::log(2.0 * arma::datum::pi) - (1.0 / 2.0) * (std::log(tmp_determinant)) - inner_likelih;
1792+
1793+
gaus_vec(j) = tmp_val * std::exp(-inner_likelih);
1794+
}
1795+
1796+
gaus_mat.col(i) = arma::as_scalar(WEIGHTS(i)) * gaus_vec;
1797+
1798+
gaus_mat_log_lik.col(i) = gaus_vec_log;
1799+
}
1800+
1801+
arma::mat loglik1(data.n_rows, WEIGHTS.n_elem, arma::fill::zeros);
1802+
1803+
arma::rowvec loglik2(data.n_rows, arma::fill::zeros);
1804+
1805+
for (unsigned int j = 0; j < loglik1.n_rows; j++) {
1806+
1807+
arma::rowvec tmp_vec = arma::conv_to< arma::rowvec >::from(gaus_mat.row(j)) + eps;
1808+
1809+
tmp_vec /= arma::sum(tmp_vec); // normalize row-data to get probabilities
1810+
1811+
loglik1.row(j) = tmp_vec; // assign probabilities
1812+
1813+
arma::uvec log_lik_label = arma::find(tmp_vec == arma::max(tmp_vec));
1814+
1815+
loglik2(j) = arma::as_scalar(log_lik_label(0)); // assign labels
1816+
}
1817+
1818+
return Rcpp::List::create( Rcpp::Named("Log_likelihood_raw") = gaus_mat_log_lik,
1819+
Rcpp::Named("cluster_proba") = loglik1,
1820+
Rcpp::Named("cluster_labels") = loglik2 );
1821+
}
1822+
1823+
1824+
1825+
17521826
// function to calculate bic-aic
17531827
//
17541828

man/predict_GMM.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,21 @@ BEGIN_RCPP
199199
return rcpp_result_gen;
200200
END_RCPP
201201
}
202+
// predict_MGausDPDF_full
203+
Rcpp::List predict_MGausDPDF_full(arma::mat data, arma::mat CENTROIDS, arma::cube COVARIANCE, arma::vec WEIGHTS, double eps);
204+
RcppExport SEXP _ClusterR_predict_MGausDPDF_full(SEXP dataSEXP, SEXP CENTROIDSSEXP, SEXP COVARIANCESEXP, SEXP WEIGHTSSEXP, SEXP epsSEXP) {
205+
BEGIN_RCPP
206+
Rcpp::RObject rcpp_result_gen;
207+
Rcpp::RNGScope rcpp_rngScope_gen;
208+
Rcpp::traits::input_parameter< arma::mat >::type data(dataSEXP);
209+
Rcpp::traits::input_parameter< arma::mat >::type CENTROIDS(CENTROIDSSEXP);
210+
Rcpp::traits::input_parameter< arma::cube >::type COVARIANCE(COVARIANCESEXP);
211+
Rcpp::traits::input_parameter< arma::vec >::type WEIGHTS(WEIGHTSSEXP);
212+
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
213+
rcpp_result_gen = Rcpp::wrap(predict_MGausDPDF_full(data, CENTROIDS, COVARIANCE, WEIGHTS, eps));
214+
return rcpp_result_gen;
215+
END_RCPP
216+
}
202217
// GMM_arma_AIC_BIC
203218
arma::rowvec GMM_arma_AIC_BIC(arma::mat& data, arma::rowvec max_clusters, std::string dist_mode, std::string seed_mode, int km_iter, int em_iter, bool verbose, double var_floor, std::string criterion, int seed);
204219
RcppExport SEXP _ClusterR_GMM_arma_AIC_BIC(SEXP dataSEXP, SEXP max_clustersSEXP, SEXP dist_modeSEXP, SEXP seed_modeSEXP, SEXP km_iterSEXP, SEXP em_iterSEXP, SEXP verboseSEXP, SEXP var_floorSEXP, SEXP criterionSEXP, SEXP seedSEXP) {

src/export_inst_folder_headers.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,21 @@ Rcpp::List predict_MGausDPDF(arma::mat data, arma::mat CENTROIDS, arma::mat COVA
294294

295295

296296

297+
// predict function for full covariance matrices (3D cube)
298+
// This handles the case when GMM is fitted with full_covariance_matrices=TRUE
299+
//
300+
301+
// [[Rcpp::export]]
302+
Rcpp::List predict_MGausDPDF_full(arma::mat data, arma::mat CENTROIDS, arma::cube COVARIANCE, arma::vec WEIGHTS, double eps = 1.0e-8) {
303+
304+
ClustHeader CRH;
305+
306+
return CRH.predict_MGausDPDF_full(data, CENTROIDS, COVARIANCE, WEIGHTS, eps);
307+
}
308+
309+
310+
311+
297312
// function to calculate bic-aic
298313
//
299314

src/init.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ extern SEXP _ClusterR_opt_clust_fK(SEXP, SEXP, SEXP);
2525
extern SEXP _ClusterR_OptClust(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
2626
extern SEXP _ClusterR_predict_medoids(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
2727
extern SEXP _ClusterR_predict_MGausDPDF(SEXP, SEXP, SEXP, SEXP, SEXP);
28+
extern SEXP _ClusterR_predict_MGausDPDF_full(SEXP, SEXP, SEXP, SEXP, SEXP);
2829
extern SEXP _ClusterR_Predict_mini_batch_kmeans(SEXP, SEXP, SEXP, SEXP);
2930
extern SEXP _ClusterR_preferenceRange(SEXP, SEXP, SEXP);
3031
extern SEXP _ClusterR_SCALE(SEXP, SEXP, SEXP);
@@ -50,6 +51,7 @@ static const R_CallMethodDef CallEntries[] = {
5051
{"_ClusterR_OptClust", (DL_FUNC) &_ClusterR_OptClust, 12},
5152
{"_ClusterR_predict_medoids", (DL_FUNC) &_ClusterR_predict_medoids, 7},
5253
{"_ClusterR_predict_MGausDPDF", (DL_FUNC) &_ClusterR_predict_MGausDPDF, 5},
54+
{"_ClusterR_predict_MGausDPDF_full", (DL_FUNC) &_ClusterR_predict_MGausDPDF_full, 5},
5355
{"_ClusterR_Predict_mini_batch_kmeans", (DL_FUNC) &_ClusterR_Predict_mini_batch_kmeans, 4},
5456
{"_ClusterR_preferenceRange", (DL_FUNC) &_ClusterR_preferenceRange, 3},
5557
{"_ClusterR_SCALE", (DL_FUNC) &_ClusterR_SCALE, 3},

0 commit comments

Comments
 (0)