
Commit 6e6561e

Author: Alexander Ororbia (committed)

cleaned up mixtures and finished debugging EMM/works on example

1 parent 519896e commit 6e6561e

File tree

3 files changed: +157 additions, -160 deletions

ngclearn/utils/density/bernoulliMixture.py

Lines changed: 44 additions & 36 deletions
@@ -18,13 +18,13 @@ def _log_bernoulli_pdf(X, p):
     Args:
         X: a design matrix (dataset) to compute the log likelihood of
 
-        mu: a parameter mean vector
+        p: a parameter mean vector (positive case probability)
 
     Returns:
         the log likelihood (scalar) of this design matrix X
     """
     #D = X.shape[1] * 1. ## get dimensionality
-    ## x log(mu_k) + (1-x) log(1 - mu_k)
+    ## general format: x log(mu_k) + (1-x) log(1 - mu_k)
     vec_ll = X * jnp.log(p) + (1. - X) * jnp.log(1. - p) ## binary cross-entropy (log Bernoulli)
     log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
     return log_ll
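
For reference, the per-datapoint log-likelihood above is a summed log-Bernoulli (binary cross-entropy) term. A quick toy check of the same computation (the values below are hypothetical, not from the repository):

from jax import numpy as jnp

X = jnp.asarray([[1., 0., 1.],
                 [0., 0., 1.]])       ## two binary data points, three dimensions
p = jnp.asarray([[0.9, 0.2, 0.7]])    ## one component's mean vector
vec_ll = X * jnp.log(p) + (1. - X) * jnp.log(1. - p)   ## elementwise log Bernoulli
log_ll = jnp.sum(vec_ll, axis=1, keepdims=True)        ## shape (2, 1): one log p(x_n) per row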
@@ -35,20 +35,27 @@ def _calc_bernoulli_pdf_vals(X, p):
     ll = jnp.exp(log_ll) ## likelihood
     return log_ll, ll
 
+@jit
+def _calc_bernoulli_mixture_stats(raw_likeli, pi):
+    likeli = raw_likeli * pi
+    gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+    likeli = jnp.sum(likeli, axis=1, keepdims=True) ## Sum_j[ pi_j * pdf_gauss(x_n; mu_j, Sigma_j) ]
+    log_likeli = jnp.log(likeli) ## vector of individual log p(x_n) values
+    complete_log_likeli = jnp.sum(log_likeli) ## complete log-likelihood for design matrix X, i.e., log p(X)
+    return log_likeli, complete_log_likeli, gamma
+
 @jit
 def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
     ## calc new means, responsibilities, and priors given current stats
     N = X.shape[0] ## get number of samples
     ## calc responsibilities
-    r = (pi * weights)
-    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
-    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
+    _pi = jnp.sum(weights, axis=0, keepdims=True) / N ## calc new priors
     ## calc weighted means (weighted by responsibilities)
-    Z = jnp.sum(r, axis=0, keepdims=True) ## partition function
+    Z = jnp.sum(weights, axis=0, keepdims=True) ## partition function
     M = (Z > 0.) * 1.
     Z = Z * M + (1. + M) ## removes div-by-0 cases
-    means = jnp.matmul(r.T, X) / Z.T
-    return means, _pi, r
+    means = jnp.matmul(weights.T, X) / Z.T
+    return _pi, means
 
 @partial(jit, static_argnums=[1])
 def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
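
The two helpers above split one EM round: _calc_bernoulli_mixture_stats turns per-component likelihoods into responsibilities and log p(X) (E-step statistics), while _calc_priors_and_means turns responsibilities into new priors and means (M-step). A minimal sketch of the same arithmetic on hypothetical inputs (the repository routine additionally guards against zero-sum responsibility columns before dividing):

from jax import numpy as jnp

likeli = jnp.asarray([[0.30, 0.10],
                      [0.05, 0.20]])   ## p(x_n | mu_j), N=2 rows, K=2 components
pi = jnp.asarray([[0.5, 0.5]])         ## mixture priors
X = jnp.asarray([[1., 0.],
                 [0., 1.]])            ## toy binary design matrix

joint = likeli * pi                                          ## pi_j * p(x_n | mu_j)
gamma = joint / jnp.sum(joint, axis=1, keepdims=True)        ## responsibilities; rows sum to 1
log_px = jnp.log(jnp.sum(joint, axis=1, keepdims=True))      ## per-datapoint log p(x_n)
new_pi = jnp.sum(gamma, axis=0, keepdims=True) / X.shape[0]  ## M-step prior update
new_means = jnp.matmul(gamma.T, X) / jnp.sum(gamma, axis=0, keepdims=True).T  ## M-step mean update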
@@ -58,7 +65,7 @@ def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting param
 
 @partial(jit, static_argnums=[1])
 def _sample_component(dkey, n_samples, mu): ## samples a component (of mixture)
-    eps = random.bernoulli(dkey, p=mu, shape=(n_samples, mu.shape[1])) ## draw Bernoulli samples
+    x_s = random.bernoulli(dkey, p=mu, shape=(n_samples, mu.shape[1])) ## draw Bernoulli samples
     return x_s
 
 ########################################################################################################################
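
The component sampler simply broadcasts a (1 x D) mean vector over the requested number of Bernoulli draws. A minimal sketch (key and values are hypothetical):

from jax import numpy as jnp, random

key = random.PRNGKey(0)
mu = jnp.asarray([[0.9, 0.1, 0.5]])
x_s = random.bernoulli(key, p=mu, shape=(4, mu.shape[1]))  ## (4, 3) boolean samples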
@@ -119,31 +126,32 @@ def calc_log_likelihood(self, X):
         Returns:
             (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
         """
-        ll = 0.
+        likeli = []
         for j in range(self.K):
-            log_ll_j, ll_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
-            ll = ll_j + ll
-        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
-        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
-        return log_ll, complete_ll
+            _, likeli_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
+            likeli.append(likeli_j)
+        likeli = jnp.concat(likeli, axis=1)
+        log_likeli_vec, complete_log_likeli, gamma = _calc_bernoulli_mixture_stats(likeli, self.pi)
+        return log_likeli_vec, complete_log_likeli
 
     def _E_step(self, X): ## Expectation (E) step, co-routine
-        weights = []
+        likeli = []
         for j in range(self.K):
-            log_ll_j, ll_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
-            weights.append( ll_j )
-        weights = jnp.concat(weights, axis=1)
-        return weights ## data-dependent weights (intermediate responsibilities)
+            _, likeli_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
+            likeli.append(likeli_j)
+        likeli = jnp.concat(likeli, axis=1)
+        log_likeli_vec, complete_log_likeli, gamma = _calc_bernoulli_mixture_stats(likeli, self.pi)
+        ## gamma => data-dependent weights (responsibilities)
+        return gamma, log_likeli_vec, complete_log_likeli
 
     def _M_step(self, X, weights): ## Maximization (M) step, co-routine
-        means, pi, r = _calc_priors_and_means(X, weights, self.pi)
-        self.pi = pi ## store new prior parameters
-        # calc weighted covariances
+        pi, means = _calc_priors_and_means(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
         for j in range(self.K):
-            #r_j = r[:, j:j + 1]
+            #r_j = weights[:, j:j + 1] ## get j-th responsibility slice
             mu_j = means[j:j + 1, :]
-            self.mu[j] = mu_j ## store new mean(j) parameter
-        return means, r
+            self.mu[j] = mu_j ## store new mean(j) parameter
+        return pi, means
 
     def fit(self, X, tol=1e-3, verbose=False):
         """
@@ -159,11 +167,11 @@ def fit(self, X, tol=1e-3, verbose=False):
         """
         means_prev = jnp.concat(self.mu, axis=0)
         for i in range(self.max_iter):
-            self.update(X) ## carry out one E-step followed by an M-step
-            means = jnp.concat(self.mu, axis=0)
+            gamma, pi, means, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+            #means = jnp.concat(self.mu, axis=0)
             dom = jnp.linalg.norm(means - means_prev) ## norm of difference-of-means
             if verbose:
-                print(f"{i}: Mean-diff = {dom}")
+                print(f"{i}: Mean-diff = {dom} log(p(X)) = {complete_loglikeli} nats")
             #print(jnp.linalg.norm(means - means_prev))
             if tol >= 0. and dom < tol:
                 print(f"Converged after {i + 1} iterations.")
@@ -177,8 +185,9 @@ def update(self, X):
         Args:
             X: the dataset / design matrix to fit this BMM to
         """
-        r_w = self._E_step(X) ## carry out E-step
-        means, respon = self._M_step(X, r_w) ## carry out M-step
+        gamma, _, complete_likeli = self._E_step(X) ## carry out E-step
+        pi, means = self._M_step(X, gamma) ## carry out M-step
+        return gamma, pi, means, complete_likeli
 
     def sample(self, n_samples, mode_j=-1):
         """
@@ -193,15 +202,14 @@ def sample(self, n_samples, mode_j=-1):
         Returns:
             Design matrix of samples drawn under the distribution defined by this BMM
         """
-        ## sample prior
         self.key, *skey = random.split(self.key, 3)
-        if mode_j >= 0: ## sample from a particular mode / component
-            mu_j = self.mu[mode_j]
+        if mode_j >= 0: ## sample from a particular mode
+            mu_j = self.mu[mode_j] ## directly select a specific component
             Xs = _sample_component(skey[0], n_samples=n_samples, mu=mu_j)
         else: ## sample from full mixture distribution
-            ## sample components/latents
+            ## sample (prior) components/latents
             lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
-            ## then sample chosen component Bernoulli
+            ## then sample chosen component Bernoulli(s)
             Xs = []
             for j in range(self.K):
                 freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
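
After this commit, fit(X) reports log p(X) alongside the mean-difference criterion and update(X) returns (gamma, pi, means, complete log-likelihood). A usage sketch under stated assumptions -- the class name, constructor arguments, and the init(X) call below are assumptions (they are not shown in this diff), while fit/update/calc_log_likelihood/sample follow the signatures above:

from jax import numpy as jnp, random

from ngclearn.utils.density.bernoulliMixture import BernoulliMixture  ## class name assumed from the filename

key = random.PRNGKey(42)
X = (random.uniform(key, (500, 16)) < 0.25) * 1.   ## hypothetical binary design matrix

bmm = BernoulliMixture(K=3, max_iter=100)          ## constructor arguments are an assumption
bmm.init(X)                                        ## assumed initializer (mirrors the EMM init below)
bmm.fit(X, tol=1e-3, verbose=True)                 ## prints Mean-diff and log(p(X)) per iteration
gamma, pi, means, logpX = bmm.update(X)            ## new four-value return signature
log_ll_vec, logpX = bmm.calc_log_likelihood(X)
Xs = bmm.sample(n_samples=64)                      ## draw from the full mixture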

ngclearn/utils/density/exponentialMixture.py

Lines changed: 71 additions & 85 deletions
@@ -1,65 +1,49 @@
 from jax import numpy as jnp, random, jit, scipy
 from functools import partial
 import time, sys
-import numpy as np
 
 from ngclearn.utils.density.mixture import Mixture
 
 ########################################################################################################################
 ## internal routines for mixture model
 ########################################################################################################################
-
 @jit
-def _log_exponential_pdf(X, rate):
+def _log_exponential_pdf(X, lmbda):
     """
-    Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given parameter
+    Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given parameter
     probability `p`.
 
     Args:
         X: a design matrix (dataset) to compute the log likelihood of
 
-        rate: a parameter rate vector
+        lmbda: a parameter rate vector
 
     Returns:
        the log likelihood (scalar) of this design matrix X
     """
-    #D = X.shape[1] * 1. ## get dimensionality
-    ## pdf(x; r) = r * np.exp(-r * x), where r is "rate"
-    ## log (r exp(-r x) ) = log(r) + log(exp(-r x) = log(r) - r x
-    vec_ll = -(X * rate) + jnp.log(rate) ## log exponential
-    log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
-    return log_ll
+    log_pdf = -jnp.matmul(X, lmbda.T) + jnp.sum(jnp.log(lmbda.T), axis=0)
+    return log_pdf
 
 @jit
-def _calc_exponential_pdf_vals(X, p):
-    log_ll = _log_exponential_pdf(X, p) ## get log-likelihood
-    ll = jnp.exp(log_ll) ## likelihood
-    return log_ll, ll
-
-#@jit
-def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
-    ## calc new rates, responsibilities, and priors given current stats
-    N = X.shape[0] ## get number of samples
-    ## calc responsibilities
-    r = (pi * weights)
-    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
-    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
-    ## calc weighted rates (weighted by responsibilities)
-
-    Znum = jnp.sum(r, axis=0, keepdims=True)
-    #print(Znum.shape)
-    Zden = jnp.matmul(r.T, X)
-    rates = Znum.T/Zden
-    #print(Zden.shape)
-    #exit()
-    """
-    Z = jnp.sum(r, axis=0, keepdims=True) ## calc partition function
-    Ndata = jnp.matmul(r.T, X)
-    M = (Ndata > 0.) * 1.
-    Ndata = Ndata * M + (1. - M) ## we mask out division-by-0 cases
-    rates = Z.T / Ndata
-    """
-    return rates, _pi, r
+def _calc_exponential_mixture_stats(X, lmbda, pi):
+    log_exp_pdf = _log_exponential_pdf(X, lmbda)
+    log_likeli = log_exp_pdf + jnp.log(pi) ## raw log-likelihood
+    likeli = jnp.exp(log_likeli) ## raw likelihood
+    gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+    weighted_log_likeli = jnp.sum(log_likeli * gamma, axis=1, keepdims=True) ## get weighted EMM log-likelihood
+    complete_loglikeli = jnp.sum(weighted_log_likeli) ## complete log-likelihood for design matrix X, i.e., log p(X)
+    return log_likeli, likeli, gamma, weighted_log_likeli, complete_loglikeli
+
+@jit
+def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
+    ## compute updates to pi params
+    Zk = jnp.sum(weights, axis=0, keepdims=True) ## summed weights/responsibilities; 1 x K
+    Z = jnp.sum(Zk) ## partition function
+    pi = Zk / Z
+    ## compute updates to lmbda params
+    Z = jnp.matmul(weights.T, X)
+    lmbda = Zk.T / Z
+    return pi, lmbda
 
 @partial(jit, static_argnums=[1])
 def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
@@ -70,18 +54,17 @@ def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting param
 @partial(jit, static_argnums=[1])
 def _sample_component(dkey, n_samples, rate): ## samples a component (of mixture)
     ## sampling ~[exp(rx)] is same as r * [~exp(x)]
-    eps = jax.random.exponential(dkey, shape=(n_samples, mu.shape[1])) * rate ## draw exponential samples
+    x_s = random.exponential(dkey, shape=(n_samples, rate.shape[1])) * rate ## draw exponential samples
     return x_s
 
 ########################################################################################################################
 
 class ExponentialMixture(Mixture): ## Exponential mixture model (mixture-of-exponentials)
     """
-    Implements a exponential mixture model (EMM) -- or mixture of exponentials (MoExp).
-    Adaptation of parameters is conducted via the Expectation-Maximization (EM)
-    learning algorithm. Note that this exponential mixture assumes that each component
-    is a factorizable mutlivariate exponential distribution. (A Categorical distribution
-    is assumed over the latent variables).
+    Implements a exponential mixture model (EMM) -- or mixture of exponentials (MoExp). Adaptation of parameters is
+    conducted via the Expectation-Maximization (EM) learning algorithm. Note that this exponential mixture assumes that
+    each component is a factorizable mutlivariate exponential distribution. (A Categorical distribution is assumed over
+    the latent variables).
 
     Args:
         K: the number of components/latent variables within this EMM
@@ -110,15 +93,20 @@ def init(self, X):
 
         """
         dim = X.shape[1]
-        self.key, *skey = random.split(self.key, 3)
-        self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
-        ptrs = random.permutation(skey[0], X.shape[0])
+        self.key, *skey = random.split(self.key, 4)
+        ## Computed jittered initial phi param values
+        #self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+        pi = jnp.ones((1, self.K))
+        eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, self.K))
+        pi = pi * eps
+        self.pi = pi / jnp.sum(pi)
+
+        ## Computed jittered initial rate (lmbda) param values
+        lmbda_h = 1.0/jnp.mean(X, axis=0, keepdims=True)
+        lmbda = random.uniform(skey[1], minval=0.99, maxval=1.01, shape=(self.K, dim)) * lmbda_h
         self.rate = []
-        for j in range(self.K):
-            ptr = ptrs[j]
-            self.key, *skey = random.split(self.key, 3)
-            eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, dim)) ## jitter initial rate params
-            self.rate.append(eps)
+        for j in range(self.K): ## set rates/lmbdas
+            self.rate.append(lmbda[j:j+1, :])
 
     def calc_log_likelihood(self, X):
         """
@@ -131,31 +119,26 @@ def calc_log_likelihood(self, X):
         Returns:
             (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
         """
-        ll = 0.
-        for j in range(self.K):
-            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
-            ll = ll_j + ll
-        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
-        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
-        return log_ll, complete_ll
+        pi = self.pi ## get prior weight values
+        lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+        ## compute relevant log-likelihoods/likelihoods
+        log_ll, ll, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+        return weighted_loglikeli, complete_likeli
 
     def _E_step(self, X): ## Expectation (E) step, co-routine
-        weights = []
-        for j in range(self.K):
-            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
-            weights.append( ll_j )
-        weights = jnp.concat(weights, axis=1)
-        return weights ## data-dependent weights (intermediate responsibilities)
+        pi = self.pi ## get prior weight values
+        lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+        _, _, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+        ## Note: responsibility weights gamma have shape => N x K
+        return gamma, weighted_loglikeli, complete_likeli
 
     def _M_step(self, X, weights): ## Maximization (M) step, co-routine
-        rates, pi, r = _calc_priors_and_rates(X, weights, self.pi)
-        self.pi = pi ## store new prior parameters
-        # calc weighted covariances
-        for j in range(self.K):
-            #r_j = r[:, j:j + 1]
-            rate_j = rates[j:j + 1, :]
-            self.rate[j] = rate_j ## store new rate(j) parameter
-        return rates, r
+        ## compute updates to pi and lmbda params
+        pi, lmbda = _calc_priors_and_rates(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
+        for j in range(self.K): ## store new rate/lmbda parameters
+            self.rate[j] = lmbda[j:j+1, :]
+        return pi, lmbda
 
     def fit(self, X, tol=1e-3, verbose=False):
         """
@@ -171,11 +154,11 @@ def fit(self, X, tol=1e-3, verbose=False):
         """
         rates_prev = jnp.concat(self.rate, axis=0)
        for i in range(self.max_iter):
-            self.update(X) ## carry out one E-step followed by an M-step
-            rates = jnp.concat(self.rate, axis=0)
+            gamma, pi, rates, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+            #rates = jnp.concat(self.rate, axis=0)
             dor = jnp.linalg.norm(rates - rates_prev) ## norm of difference-of-rates
             if verbose:
-                print(f"{i}: Rate-diff = {dor}")
+                print(f"{i}: Rate-diff = {dor} log(p(X)) = {complete_loglikeli} nats")
             #print(jnp.linalg.norm(rates - rates_prev))
             if tol >= 0. and dor < tol:
                 print(f"Converged after {i + 1} iterations.")
@@ -188,9 +171,13 @@ def update(self, X):
 
         Args:
             X: the dataset / design matrix to fit this BMM to
+
+        Returns:
+            responsibilities (gamma), priors (pi), rates (lambda), EMM log-likelihood
         """
-        r_w = self._E_step(X) ## carry out E-step
-        rates, respon = self._M_step(X, r_w) ## carry out M-step
+        gamma, _, complete_log_likeli = self._E_step(X) ## carry out E-step
+        pi, rates = self._M_step(X, gamma) ## carry out M-step
+        return gamma, pi, rates, complete_log_likeli
 
     def sample(self, n_samples, mode_j=-1):
         """
@@ -205,15 +192,14 @@ def sample(self, n_samples, mode_j=-1):
         Returns:
             Design matrix of samples drawn under the distribution defined by this EMM
         """
-        ## sample prior
         self.key, *skey = random.split(self.key, 3)
-        if mode_j >= 0: ## sample from a particular mode / component
-            rate_j = self.rate[mode_j]
+        if mode_j >= 0: ## sample from a particular mode
+            rate_j = self.rate[mode_j] ## directly select a specific component
             Xs = _sample_component(skey[0], n_samples=n_samples, rate=rate_j)
         else: ## sample from full mixture distribution
-            ## sample components/latents
+            ## sample (prior) components/latents
            lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
-            ## then sample chosen component exponential
+            ## then sample chosen component exponential(s)
             Xs = []
             for j in range(self.K):
                 freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
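
For reference, the matmul form of the exponential log-pdf introduced above is algebraically the same as summing log(lmbda_d) - lmbda_d * x_d over dimensions, and the jittered initializer's lmbda ~= 1/mean(X) matches the maximum-likelihood rate of a single exponential. A small self-contained check on toy values:

from jax import numpy as jnp

X = jnp.asarray([[0.5, 1.0, 2.0],
                 [0.1, 0.2, 0.3]])        ## N=2 data points, D=3 dims
lmbda = jnp.asarray([[1.0, 2.0, 0.5],
                     [3.0, 1.0, 1.0]])    ## K=2 components x D=3 rates

log_pdf = -jnp.matmul(X, lmbda.T) + jnp.sum(jnp.log(lmbda.T), axis=0)         ## (N, K), matmul form
per_dim = jnp.stack([jnp.sum(jnp.log(l) - l * X, axis=1) for l in lmbda]).T   ## same quantity, dimension-wise
assert jnp.allclose(log_pdf, per_dim)

lmbda_h = 1.0 / jnp.mean(X, axis=0, keepdims=True)   ## per-dimension MLE of a single exponential's rate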
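A usage sketch for the updated ExponentialMixture under stated assumptions -- the constructor arguments below are assumptions (the diff uses self.K, self.max_iter and self.key, but the constructor itself lives on the Mixture base class and is not shown here); init, fit, update and sample follow the signatures in the diff:

from jax import numpy as jnp, random

from ngclearn.utils.density.exponentialMixture import ExponentialMixture

key = random.PRNGKey(1234)
k1, k2 = random.split(key)
## hypothetical non-negative data drawn from two exponential modes with different rates
X = jnp.concatenate([random.exponential(k1, (200, 5)) / 0.5,
                     random.exponential(k2, (200, 5)) / 4.0], axis=0)

emm = ExponentialMixture(K=2, max_iter=100)   ## constructor arguments are an assumption
emm.init(X)                                   ## jittered pi / lmbda initialization from this commit
emm.fit(X, tol=1e-3, verbose=True)            ## prints Rate-diff and log(p(X)) per iteration
gamma, pi, rates, logpX = emm.update(X)       ## update(X) now returns these four values
Xs = emm.sample(n_samples=32, mode_j=0)       ## draw from a single component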
