Commit 85d1282

Author: Alexander Ororbia (committed)
implemented in-house gmm, in-built to ngclearn; tested on gaussian mode data
1 parent e3b8ef6 commit 85d1282

File tree: 1 file changed (+227 −43 lines changed)


ngclearn/utils/density/gmm.py

Lines changed: 227 additions & 43 deletions
@@ -1,19 +1,105 @@
-from jax import numpy as jnp, random, jit
+from jax import numpy as jnp, random, jit, scipy
 from functools import partial
 import time, sys
 import numpy as np
-#from sklearn import mixture
-#from sklearn.cluster import KMeans
-from scipy.stats import multivariate_normal
-#from ngclearn.utils.stat_utils import calc_log_gauss_pdf
-from ngclearn.utils.model_utils import softmax
-#from kmeans import K_Means
-from sklearn import mixture
-
-#seed = 69
-#tf.random.set_seed(seed=seed)
-
-class GMM:
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+
+@partial(jit, static_argnums=[3])
+def _log_gaussian_pdf(X, mu, Sigma, use_chol_prec=True):
+    """
+    Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X`, under a given parameter mean
+    `mu` and parameter covariance `Sigma`.
+
+    Args:
+        X: a design matrix (dataset) to compute the log likelihood of
+        mu: a parameter mean vector
+        Sigma: a parameter covariance matrix
+        use_chol_prec: should this routine use Cholesky-factor computation of the precision (Default: True)
+
+    Returns:
+        the log likelihood (scalar) of this design matrix X
+    """
+    D = mu.shape[1] * 1. ## get dimensionality
+    if use_chol_prec: ## use Cholesky-factor calc of precision
+        C = jnp.linalg.cholesky(Sigma) # calc_prec_chol(mu, cov)
+        inv_C = jnp.linalg.pinv(C)
+        precision = jnp.matmul(inv_C.T, inv_C)
+    else: ## use Moore-Penrose pseudo-inverse calc of precision
+        precision = jnp.linalg.pinv(Sigma)
+    ## finish computing log-likelihood
+    sign_ld, abs_ld = jnp.linalg.slogdet(Sigma)
+    log_det_sigma = abs_ld * sign_ld ## log-determinant of the covariance Sigma
+    Z = X - mu ## calc deltas
+    quad_term = jnp.sum((jnp.matmul(Z, precision) * Z), axis=1, keepdims=True) ## LL quadratic term
+    return -(jnp.log(2. * np.pi) * D + log_det_sigma + quad_term) * 0.5
+
+@partial(jit, static_argnums=[3])
+def _calc_gaussian_pdf_vals(X, mu, Sigma, use_chol_prec=True):
+    log_ll = _log_gaussian_pdf(X, mu, Sigma, use_chol_prec)
+    ll = jnp.exp(log_ll)
+    return log_ll, ll
+
+@partial(jit, static_argnums=[3])
+def _calc_weighted_cov(X, mu, weights, assume_diag_cov=False): ## M-step co-routine
+    ## calc new covariance Sigma given data, means, and responsibilities
+    diff = X - mu
+    sigma_j = jnp.matmul((weights * diff).T, diff) / jnp.sum(weights)
+    if assume_diag_cov:
+        sigma_j = sigma_j * jnp.eye(sigma_j.shape[1])
+    return sigma_j
+
+@jit
+def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
+    ## calc new means, responsibilities, and priors given current stats
+    N = X.shape[0] ## get number of samples
+    ## calc responsibilities
+    r = (pi * weights)
+    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
+    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
+    ## calc weighted means (weighted by responsibilities)
+    means = jnp.matmul(r.T, X) / jnp.sum(r, axis=0, keepdims=True).T
+    return means, _pi, r
+
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+    log_pi = jnp.log(pi) ## calc log(prior)
+    lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+    return lats
+
+@partial(jit, static_argnums=[1, 4])
+def _sample_component(dkey, n_samples, mu, Sigma, assume_diag_cov=False): ## samples a component (of mixture)
+    eps = random.normal(dkey, shape=(n_samples, mu.shape[1])) ## draw unit Gaussian noise
+    ## apply scale-shift transformation
+    if assume_diag_cov:
+        R = jnp.sum(jnp.sqrt(Sigma), axis=0, keepdims=True)
+        x_s = mu + eps * R
+    else:
+        R = jnp.linalg.cholesky(Sigma) ## decompose covariance via Cholesky
+        x_s = mu + jnp.matmul(eps, R) # tf.matmul(eps, R)
+    return x_s
+
+# def _log_gaussian_pdf(X, mu, sigma):
+#     C = jnp.linalg.cholesky(sigma) #calc_prec_chol(mu, cov)
+#     inv_C = jnp.linalg.pinv(C)
+#     prec_chol = jnp.matmul(inv_C, inv_C.T)
+#     #prec_chol = jnp.linalg.inv(sigma)
+#
+#     N, D = X.shape ## n_samples x dimensionality
+#     # det(precision_chol) is half of det(precision)
+#     sign_ld, abs_ld = jnp.linalg.slogdet(prec_chol)
+#     log_det = abs_ld * sign_ld ## log determinant of Cholesky precision
+#     y = jnp.matmul(X, prec_chol) - jnp.matmul(mu, prec_chol)
+#     log_prob = jnp.sum(y * y, axis=1, keepdims=True)
+#     #return -0.5 * (D * jnp.log(np.pi * 2) + log_prob) + log_det
+#     #return -0.5 * (D * jnp.log(np.pi * 2) + log_det + log_prob)
+#     return -jnp.log(np.pi * 2) * (D * 0.5) - log_det * 0.5 - log_prob * 0.5
+
+########################################################################################################################
+
+class GMM: ## Gaussian mixture model (mixture-of-Gaussians)
     """
     Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians, MoG.
     Adaptation of parameters is conducted via the Expectation-Maximization (EM)
@@ -24,59 +110,157 @@ class GMM:
     The sampling process has been rewritten to utilize GPU matrix computation.

     Args:
-        k: the number of components/latent variables within this GMM
+        K: the number of components/latent variables within this GMM

-        max_iter: the maximum number of EM iterations to fit parameters to data
-            (Default = 5)
+        max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)

-        assume_diag_cov: if True, assumes a diagonal covariance for each component
-            (Default = False)
+        assume_diag_cov: if True, assumes a diagonal covariance for each component (Default = False)

-        init_kmeans: if True, first learn use the K-Means algorithm to initialize
-            the component Gaussians of this GMM (Default = True)
+        init_kmeans: <Unsupported>
     """
-    def __init__(self, k, max_iter=5, assume_diag_cov=False, init_kmeans=True):
-        self.use_sklearn = True
-        self.k = k
+    # init_kmeans: if True, first learn use the K-Means algorithm to initialize
+    #     the component Gaussians of this GMM (Default = False)
+
+    def __init__(self, K, max_iter=50, assume_diag_cov=False, init_kmeans=False, key=None):
+        self.K = K
         self.max_iter = int(max_iter)
         self.assume_diag_cov = assume_diag_cov
-        self.init_kmeans = init_kmeans
-        self.mu = []
-        self.sigma = []
-        self.prec = []
-        self.weights = None
-        self.z_weights = None # variables for parameterizing weights for SGD
-
-    def fit(self, data):
+        self.init_kmeans = init_kmeans ## Unsupported currently
+        self.mu = [] ## component mean parameters
+        self.Sigma = [] ## component covariance parameters
+        self.pi = None ## prior weight parameters
+        #self.z_weights = None # variables for parameterizing weights for SGD
+        self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+    def init(self, X):
+        """
+        Initializes this GMM in accordance to a supplied design matrix.
+
+        Args:
+            X: the design matrix to initialize this GMM to
+
+        """
+        dim = X.shape[1]
+        self.key, *skey = random.split(self.key, 3)
+        self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+        ptrs = random.permutation(skey[0], X.shape[0])
+        for j in range(self.K):
+            ptr = ptrs[j]
+            #self.key, *skey = random.split(self.key, 3)
+            self.mu.append(X[ptr:ptr+1,:])
+            Sigma_j = jnp.eye(dim)
+            #sigma_j = random.uniform(skey[0], minval=0.01, maxval=0.9, shape=(dim, dim))
+            self.Sigma.append(Sigma_j)
+
+    def calc_log_likelihood(self, X):
+        """
+        Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X`, under the current
+        parameters of this Gaussian mixture model.
+
+        Args:
+            X: the design matrix to estimate log likelihood values over under this GMM
+
+        Returns:
+            (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
+        """
+        ll = 0.
+        for j in range(self.K):
+            #log_ll_j = log_gaussian_pdf(X, self.mu[j], self.Sigma[j])
+            #ll_j = jnp.exp(log_ll_j) * self.pi[:,j]
+            log_ll_j, ll_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+            ll = ll_j + ll
+        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
+        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
+        return log_ll, complete_ll
+
+    def _E_step(self, X): ## Expectation (E) step, co-routine
+        weights = []
+        for j in range(self.K):
+            #jax.scipy.stats.multivariate_normal.logpdf(x, mean, cov)
+            #log_ll_j = log_gaussian_pdf(X, self.mu[j], self.Sigma[j])
+            log_ll_j, ll_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+            # log_ll_j = scipy.stats.multivariate_normal.logpdf(X, self.mu[j], self.Sigma[j])
+            # log_ll_j = jnp.expand_dims(log_ll_j, axis=1)
+            weights.append( ll_j )
+        weights = jnp.concat(weights, axis=1)
+        return weights ## data-dependent weights (intermediate responsibilities)
+
+    def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+        means, pi, r = _calc_priors_and_means(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
+        # calc weighted covariances
+        for j in range(self.K):
+            r_j = r[:, j:j + 1]
+            mu_j = means[j:j + 1, :]
+            sigma_j = _calc_weighted_cov(X, mu_j, r_j, assume_diag_cov=self.assume_diag_cov)
+            self.mu[j] = mu_j ## store new mean(j) parameter
+            self.Sigma[j] = sigma_j ## store new covariance(j) parameter
+        return means, r
+
+    def fit(self, X, tol=1e-3):
         """
         Run full fitting process of this GMM.

         Args:
-            data: the dataset to fit this GMM to
+            X: the dataset to fit this GMM to
+
+            tol: the tolerance value for detecting convergence (via difference-of-means); will engage in early-stopping
+                if tol >= 0. (Default: 1e-3)
         """
-        pass
+        means_prev = jnp.concat(self.mu, axis=0)
+        for i in range(self.max_iter):
+            self.update(X) ## carry out one E-step followed by an M-step
+            means = jnp.concat(self.mu, axis=0)
+            #print(jnp.linalg.norm(means - means_prev))
+            if tol >= 0. and jnp.linalg.norm(means - means_prev) < tol:
+                print(f"Converged after {i + 1} iterations.")
+                break
+            means_prev = means

     def update(self, X):
         """
-        Performs a single iterative update of parameters (assuming model initialized)
+        Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)

         Args:
            X: the dataset / design matrix to fit this GMM to
         """
-        pass
+        r_w = self._E_step(X) ## carry out E-step
+        means, respon = self._M_step(X, r_w) ## carry out M-step

-    def sample(self, n_s, mode_i=-1, samples_modes_evenly=False):
+    def sample(self, n_samples, mode_j=-1):
         """
-        (Efficiently) Draw samples from the current underlying GMM model
+        Draw samples from the current underlying GMM model

         Args:
-            n_s: the number of samples to draw from this GMM
+            n_samples: the number of samples to draw from this GMM

-            mode_i: if >= 0, will only draw samples from a specific component of this GMM
+            mode_j: if >= 0, will only draw samples from a specific component of this GMM
                 (Default = -1), ignoring the Categorical prior over latent variables/components

-            samples_modes_evenly: if True, will ignore the Categorical prior over latent
-                variables/components and draw an approximately equal number of samples from
-                each component
+        Returns:
+            Design matrix of samples drawn under the distribution defined by this GMM
         """
-        pass
+        ## sample prior
+        self.key, *skey = random.split(self.key, 3)
+        if mode_j >= 0: ## sample from a particular mode / component
+            mu_j = self.mu[mode_j]
+            Sigma_j = self.Sigma[mode_j]
+            Xs = _sample_component(
+                skey[0], n_samples=n_samples, mu=mu_j, Sigma=Sigma_j, assume_diag_cov=self.assume_diag_cov
+            )
+        else: ## sample from full mixture distribution
+            ## sample components/latents
+            lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+            ## then sample chosen component Gaussians
+            Xs = []
+            for j in range(self.K):
+                freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+                print(freq_j)
+                ## draw unit Gaussian noise
+                self.key, *skey = random.split(self.key, 3)
+                x_s = _sample_component(
+                    skey[0], n_samples=freq_j, mu=self.mu[j], Sigma=self.Sigma[j], assume_diag_cov=self.assume_diag_cov
+                )
+                Xs.append(x_s)
+            Xs = jnp.concat(Xs, axis=0)
+        return Xs
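
Reader's note (not part of the commit): per row x_n of the design matrix X, the new _log_gaussian_pdf routine evaluates the standard multivariate Gaussian log-density,

    \log \mathcal{N}(x_n \mid \mu, \Sigma) = -\tfrac{1}{2}\left[ D \log(2\pi) + \log\lvert\Sigma\rvert + (x_n - \mu)^\top \Sigma^{-1} (x_n - \mu) \right],

where D is the data dimensionality and the precision \Sigma^{-1} is obtained either from a Cholesky factor of \Sigma or from a Moore-Penrose pseudo-inverse, matching the two branches of the routine above.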

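For orientation, a minimal usage sketch of the new in-house GMM follows. It is not part of the commit: the import path is inferred from the file location ngclearn/utils/density/gmm.py, and the two-mode toy data simply mirrors the "gaussian mode data" mentioned in the commit message.

    ## hypothetical usage sketch (assumes GMM is importable from its file path)
    from jax import numpy as jnp, random
    from ngclearn.utils.density.gmm import GMM

    key = random.PRNGKey(1234)
    key, *subkeys = random.split(key, 3)
    ## toy two-mode dataset: two well-separated Gaussian blobs in 2D
    X1 = random.normal(subkeys[0], (200, 2)) * 0.5 + jnp.asarray([[-2., -2.]])
    X2 = random.normal(subkeys[1], (200, 2)) * 0.5 + jnp.asarray([[2., 2.]])
    X = jnp.concatenate([X1, X2], axis=0)

    model = GMM(K=2, max_iter=50, key=key)  ## two components, at most 50 EM iterations
    model.init(X)                           ## means seeded from random rows, identity covariances
    model.fit(X, tol=1e-3)                  ## EM with early stopping on mean movement
    log_ll, complete_ll = model.calc_log_likelihood(X)
    Xs = model.sample(n_samples=100)        ## draw 100 samples from the fitted mixture
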
0 commit comments
