-from jax import numpy as jnp, random, jit
+from jax import numpy as jnp, random, jit, scipy
 from functools import partial
 import time, sys
 import numpy as np
-#from sklearn import mixture
-#from sklearn.cluster import KMeans
-from scipy.stats import multivariate_normal
-#from ngclearn.utils.stat_utils import calc_log_gauss_pdf
-from ngclearn.utils.model_utils import softmax
-#from kmeans import K_Means
-from sklearn import mixture
-
-#seed = 69
-#tf.random.set_seed(seed=seed)
-
-class GMM:
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+
+@partial(jit, static_argnums=[3])
+def _log_gaussian_pdf(X, mu, Sigma, use_chol_prec=True):
+    """
+    Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X`, under a given parameter mean
+    `mu` and parameter covariance `Sigma`.
+
+    Args:
+        X: a design matrix (dataset) to compute the log likelihood of
+        mu: a parameter mean vector
+        Sigma: a parameter covariance matrix
+        use_chol_prec: should this routine use Cholesky-factor computation of the precision (Default: True)
+
+    Returns:
+        a column vector containing the log likelihood of each row (data point) of the design matrix X
+    """
+    D = mu.shape[1] * 1. ## get dimensionality
+    if use_chol_prec: ## use Cholesky-factor calc of precision
+        C = jnp.linalg.cholesky(Sigma)
+        inv_C = jnp.linalg.pinv(C)
+        precision = jnp.matmul(inv_C.T, inv_C)
+    else: ## use Moore-Penrose pseudo-inverse calc of precision
+        precision = jnp.linalg.pinv(Sigma)
+    ## finish computing log-likelihood
+    sign_ld, abs_ld = jnp.linalg.slogdet(Sigma)
+    log_det_sigma = abs_ld * sign_ld ## log-determinant of the covariance
+    Z = X - mu ## calc deltas
+    quad_term = jnp.sum((jnp.matmul(Z, precision) * Z), axis=1, keepdims=True) ## LL quadratic term
+    return -(jnp.log(2. * np.pi) * D + log_det_sigma + quad_term) * 0.5
+
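+# Illustrative cross-check of the routine above (a sketch, not part of the model code; assumes a
+# well-conditioned covariance). It should agree with jax.scipy's reference log-pdf:
+#   X_chk = random.normal(random.PRNGKey(0), (8, 3))
+#   ref = scipy.stats.multivariate_normal.logpdf(X_chk, jnp.zeros(3), jnp.eye(3))   ## shape (8,)
+#   ours = _log_gaussian_pdf(X_chk, jnp.zeros((1, 3)), jnp.eye(3))                  ## shape (8, 1)
+#   assert jnp.allclose(ref[:, None], ours, atol=1e-4)
+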
+@partial(jit, static_argnums=[3])
+def _calc_gaussian_pdf_vals(X, mu, Sigma, use_chol_prec=True):
+    log_ll = _log_gaussian_pdf(X, mu, Sigma, use_chol_prec)
+    ll = jnp.exp(log_ll)
+    return log_ll, ll
+
+@partial(jit, static_argnums=[3])
+def _calc_weighted_cov(X, mu, weights, assume_diag_cov=False): ## M-step co-routine
+    ## calc new covariance Sigma given data, means, and responsibilities
+    diff = X - mu
+    sigma_j = jnp.matmul((weights * diff).T, diff) / jnp.sum(weights)
+    if assume_diag_cov:
+        sigma_j = sigma_j * jnp.eye(sigma_j.shape[1])
+    return sigma_j
+
+@jit
+def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
+    ## calc new means, responsibilities, and priors given current stats
+    N = X.shape[0] ## get number of samples
+    ## calc responsibilities
+    r = (pi * weights)
+    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
+    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
+    ## calc weighted means (weighted by responsibilities)
+    means = jnp.matmul(r.T, X) / jnp.sum(r, axis=0, keepdims=True).T
+    return means, _pi, r
+
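+# For reference, the two M-step co-routines above implement the standard EM updates for a GMM
+# (with r the responsibilities formed from the E-step likelihoods and the current priors):
+#   r[n, j]  = pi[j] * N(x_n | mu_j, Sigma_j) / sum_k pi[k] * N(x_n | mu_k, Sigma_k)
+#   pi[j]    = (1 / N) * sum_n r[n, j]
+#   mu_j     = sum_n r[n, j] * x_n / sum_n r[n, j]
+#   Sigma_j  = sum_n r[n, j] * (x_n - mu_j)(x_n - mu_j)^T / sum_n r[n, j]
+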
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+    log_pi = jnp.log(pi) ## calc log(prior)
+    lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+    return lats
+
+@partial(jit, static_argnums=[1, 4])
+def _sample_component(dkey, n_samples, mu, Sigma, assume_diag_cov=False): ## samples a component (of mixture)
+    eps = random.normal(dkey, shape=(n_samples, mu.shape[1])) ## draw unit Gaussian noise
+    ## apply scale-shift transformation
+    if assume_diag_cov:
+        R = jnp.sum(jnp.sqrt(Sigma), axis=0, keepdims=True)
+        x_s = mu + eps * R
+    else:
+        R = jnp.linalg.cholesky(Sigma) ## decompose covariance via Cholesky: Sigma = R R^T
+        x_s = mu + jnp.matmul(eps, R.T) ## right-multiply by R^T so that cov(x_s) = R R^T = Sigma
+    return x_s
+
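+# The non-diagonal branch above uses the reparameterization x = mu + eps @ R^T with Sigma = R @ R^T
+# (R lower-triangular). Illustrative call (a sketch; a 2D standard-normal component is assumed):
+#   xs = _sample_component(random.PRNGKey(0), n_samples=32, mu=jnp.zeros((1, 2)), Sigma=jnp.eye(2))  ## (32, 2)
+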
+# def _log_gaussian_pdf(X, mu, sigma):
+#     C = jnp.linalg.cholesky(sigma) #calc_prec_chol(mu, cov)
+#     inv_C = jnp.linalg.pinv(C)
+#     prec_chol = jnp.matmul(inv_C, inv_C.T)
+#     #prec_chol = jnp.linalg.inv(sigma)
+#
+#     N, D = X.shape ## n_samples x dimensionality
+#     # det(precision_chol) is half of det(precision)
+#     sign_ld, abs_ld = jnp.linalg.slogdet(prec_chol)
+#     log_det = abs_ld * sign_ld ## log determinant of Cholesky precision
+#     y = jnp.matmul(X, prec_chol) - jnp.matmul(mu, prec_chol)
+#     log_prob = jnp.sum(y * y, axis=1, keepdims=True)
+#     #return -0.5 * (D * jnp.log(np.pi * 2) + log_prob) + log_det
+#     #return -0.5 * (D * jnp.log(np.pi * 2) + log_det + log_prob)
+#     return -jnp.log(np.pi * 2) * (D * 0.5) - log_det * 0.5 - log_prob * 0.5
+
+########################################################################################################################
+
+class GMM: ## Gaussian mixture model (mixture-of-Gaussians)
17103 """
18104 Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians, MoG.
19105 Adaptation of parameters is conducted via the Expectation-Maximization (EM)
@@ -24,59 +110,157 @@ class GMM:
     The sampling process has been rewritten to utilize GPU matrix computation.
 
     Args:
-        k: the number of components/latent variables within this GMM
+        K: the number of components/latent variables within this GMM
 
-        max_iter: the maximum number of EM iterations to fit parameters to data
-            (Default = 5)
+        max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
 
-        assume_diag_cov: if True, assumes a diagonal covariance for each component
-            (Default = False)
+        assume_diag_cov: if True, assumes a diagonal covariance for each component (Default = False)
 
-        init_kmeans: if True, first learn use the K-Means algorithm to initialize
-            the component Gaussians of this GMM (Default = True)
+        init_kmeans: <Unsupported>
     """
-    def __init__(self, k, max_iter=5, assume_diag_cov=False, init_kmeans=True):
-        self.use_sklearn = True
-        self.k = k
+    # init_kmeans: if True, first use the K-Means algorithm to initialize
+    #     the component Gaussians of this GMM (Default = False)
+
+    def __init__(self, K, max_iter=50, assume_diag_cov=False, init_kmeans=False, key=None):
+        self.K = K
         self.max_iter = int(max_iter)
         self.assume_diag_cov = assume_diag_cov
-        self.init_kmeans = init_kmeans
-        self.mu = []
-        self.sigma = []
-        self.prec = []
-        self.weights = None
-        self.z_weights = None # variables for parameterizing weights for SGD
-
-    def fit(self, data):
+        self.init_kmeans = init_kmeans ## Unsupported currently
+        self.mu = [] ## component mean parameters
+        self.Sigma = [] ## component covariance parameters
+        self.pi = None ## prior weight parameters
+        #self.z_weights = None # variables for parameterizing weights for SGD
+        self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+    def init(self, X):
136+ """
137+ Initializes this GMM in accordance to a supplied design matrix.
138+
139+ Args:
140+ X: the design matrix to initialize this GMM to
141+
142+ """
+        dim = X.shape[1]
+        self.key, *skey = random.split(self.key, 3)
+        self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+        ptrs = random.permutation(skey[0], X.shape[0])
+        for j in range(self.K):
+            ptr = ptrs[j]
+            #self.key, *skey = random.split(self.key, 3)
+            self.mu.append(X[ptr:ptr + 1, :])
+            Sigma_j = jnp.eye(dim)
+            #sigma_j = random.uniform(skey[0], minval=0.01, maxval=0.9, shape=(dim, dim))
+            self.Sigma.append(Sigma_j)
+
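+    ## Note: `init` selects K rows of X (chosen via a random permutation) as the initial means, with identity
+    ## covariances and a uniform prior pi = 1/K; K-means-based initialization is currently unsupported.
+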
+    def calc_log_likelihood(self, X):
+        """
+        Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X` under the current
+        parameters of this Gaussian mixture model.
+
+        Args:
+            X: the design matrix to estimate log likelihood values over under this GMM
+
+        Returns:
+            a (column) vector of individual log likelihoods, and a scalar for the complete log likelihood log p(X)
+        """
+        ll = 0.
+        for j in range(self.K):
+            #log_ll_j = log_gaussian_pdf(X, self.mu[j], self.Sigma[j])
+            #ll_j = jnp.exp(log_ll_j) * self.pi[:,j]
+            log_ll_j, ll_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+            ll = ll_j * self.pi[:, j:j + 1] + ll ## weight component j's likelihood by its prior
+        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
+        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
+        return log_ll, complete_ll
+
+    def _E_step(self, X): ## Expectation (E) step, co-routine
+        weights = []
+        for j in range(self.K):
+            #jax.scipy.stats.multivariate_normal.logpdf(x, mean, cov)
+            #log_ll_j = log_gaussian_pdf(X, self.mu[j], self.Sigma[j])
+            log_ll_j, ll_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+            # log_ll_j = scipy.stats.multivariate_normal.logpdf(X, self.mu[j], self.Sigma[j])
+            # log_ll_j = jnp.expand_dims(log_ll_j, axis=1)
+            weights.append(ll_j)
+        weights = jnp.concat(weights, axis=1)
+        return weights ## data-dependent weights (intermediate responsibilities)
+
+    def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+        means, pi, r = _calc_priors_and_means(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
+        ## calc weighted covariances
+        for j in range(self.K):
+            r_j = r[:, j:j + 1]
+            mu_j = means[j:j + 1, :]
+            sigma_j = _calc_weighted_cov(X, mu_j, r_j, assume_diag_cov=self.assume_diag_cov)
+            self.mu[j] = mu_j ## store new mean(j) parameter
+            self.Sigma[j] = sigma_j ## store new covariance(j) parameter
+        return means, r
+
+    def fit(self, X, tol=1e-3):
         """
         Run full fitting process of this GMM.
 
         Args:
-            data: the dataset to fit this GMM to
+            X: the dataset to fit this GMM to
+
+            tol: the tolerance value for detecting convergence (via difference-of-means); will engage in
+                early-stopping if tol >= 0. (Default: 1e-3)
         """
-        pass
+        means_prev = jnp.concat(self.mu, axis=0)
+        for i in range(self.max_iter):
+            self.update(X) ## carry out one E-step followed by an M-step
+            means = jnp.concat(self.mu, axis=0)
+            #print(jnp.linalg.norm(means - means_prev))
+            if tol >= 0. and jnp.linalg.norm(means - means_prev) < tol:
+                print(f"Converged after {i + 1} iterations.")
+                break
+            means_prev = means
 
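+    ## Note: convergence in `fit` is measured as the Frobenius norm of the change in the stacked component means;
+    ## an illustrative call such as gmm.fit(X, tol=-1.) disables early stopping and always runs max_iter EM steps.
+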
     def update(self, X):
         """
-        Performs a single iterative update of parameters (assuming model initialized)
+        Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)
 
         Args:
             X: the dataset / design matrix to fit this GMM to
         """
-        pass
+        r_w = self._E_step(X) ## carry out E-step
+        means, respon = self._M_step(X, r_w) ## carry out M-step
 
-    def sample(self, n_s, mode_i=-1, samples_modes_evenly=False):
+    def sample(self, n_samples, mode_j=-1):
         """
-        (Efficiently) Draw samples from the current underlying GMM model
+        Draw samples from the current underlying GMM model
 
         Args:
-            n_s: the number of samples to draw from this GMM
+            n_samples: the number of samples to draw from this GMM
 
-            mode_i: if >= 0, will only draw samples from a specific component of this GMM
+            mode_j: if >= 0, will only draw samples from a specific component of this GMM
                 (Default = -1), ignoring the Categorical prior over latent variables/components
 
-            samples_modes_evenly: if True, will ignore the Categorical prior over latent
-                variables/components and draw an approximately equal number of samples from
-                each component
+        Returns:
+            Design matrix of samples drawn under the distribution defined by this GMM
         """
-        pass
+        ## sample prior
+        self.key, *skey = random.split(self.key, 3)
+        if mode_j >= 0: ## sample from a particular mode / component
+            mu_j = self.mu[mode_j]
+            Sigma_j = self.Sigma[mode_j]
+            Xs = _sample_component(
+                skey[0], n_samples=n_samples, mu=mu_j, Sigma=Sigma_j, assume_diag_cov=self.assume_diag_cov
+            )
+        else: ## sample from full mixture distribution
+            ## sample components/latents
+            lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+            ## then sample chosen component Gaussians
+            Xs = []
+            for j in range(self.K):
+                freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+                ## draw samples for component j
+                self.key, *skey = random.split(self.key, 3)
+                x_s = _sample_component(
+                    skey[0], n_samples=freq_j, mu=self.mu[j], Sigma=self.Sigma[j], assume_diag_cov=self.assume_diag_cov
+                )
+                Xs.append(x_s)
+            Xs = jnp.concat(Xs, axis=0)
+        return Xs
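+
+# Illustrative end-to-end usage (a sketch; `X_data` is an assumed (N, D) design matrix, not defined in this file):
+#   gmm = GMM(K=3, max_iter=50, key=random.PRNGKey(1234))
+#   gmm.init(X_data)                                        ## choose K rows of X_data as initial means
+#   gmm.fit(X_data, tol=1e-3)                               ## run EM with early stopping
+#   log_ll, complete_ll = gmm.calc_log_likelihood(X_data)   ## per-sample log p(x_n) and total log p(X)
+#   X_samp = gmm.sample(n_samples=100)                      ## draw new samples from the fitted mixture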