Commit 02e906e

Author: Alexander Ororbia
Commit message: added basic exp-mixture to utils.density
1 parent 5041440 · commit 02e906e

5 files changed: +234, -4 lines changed

docs/source/ngclearn.utils.density.rst

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,14 @@ ngclearn.utils.density.bernoulliMixture module
    :undoc-members:
    :show-inheritance:
 
+ngclearn.utils.density.exponentialMixture module
+------------------------------------------------
+
+.. automodule:: ngclearn.utils.density.exponentialMixture
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 ngclearn.utils.density.gaussianMixture module
 ---------------------------------------------
 

ngclearn/utils/density/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 from .mixture import Mixture ## general mixture template parent class
 ## point to supported density estimator models
-from .gaussianMixture import GaussianMixture ## Gaussian mixture model
-from .bernoulliMixture import BernoulliMixture ## Bernoulli mixture model
+from .gaussianMixture import GaussianMixture ## mixture-of-Gaussians
+from .bernoulliMixture import BernoulliMixture ## mixture-of-Bernoullis
+from .exponentialMixture import ExponentialMixture ## mixture-of-exponentials
 
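With this export in place, all three estimators are importable from ngclearn.utils.density. A quick sketch of the expanded surface (assuming, as the template comment above suggests, that every estimator subclasses Mixture):

from ngclearn.utils.density import Mixture, GaussianMixture, BernoulliMixture, ExponentialMixture

for cls in (GaussianMixture, BernoulliMixture, ExponentialMixture):
    assert issubclass(cls, Mixture) ## each density estimator extends the Mixture template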

ngclearn/utils/density/bernoulliMixture.py

Lines changed: 4 additions & 1 deletion
@@ -44,7 +44,10 @@ def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
     r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
     _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
     ## calc weighted means (weighted by responsibilities)
-    means = jnp.matmul(r.T, X) / jnp.sum(r, axis=0, keepdims=True).T
+    Z = jnp.sum(r, axis=0, keepdims=True) ## partition function
+    M = (Z > 0.) * 1.
+    Z = Z * M + (1. - M) ## removes div-by-0 cases
+    means = jnp.matmul(r.T, X) / Z.T
     return means, _pi, r
 
 @partial(jit, static_argnums=[1])
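Why the new guard matters: if some component j ends up with zero responsibility for every datapoint, the column sum over r is exactly 0 and the unguarded division produces NaNs that then corrupt every later EM step. A minimal sketch of the masking idiom in isolation (the toy responsibilities below are illustrative, not from the commit):

from jax import numpy as jnp

r = jnp.array([[0.7, 0.3, 0.0],
               [0.9, 0.1, 0.0]]) ## responsibilities; component 2 is "dead"
X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
Z = jnp.sum(r, axis=0, keepdims=True) ## per-component responsibility mass
M = (Z > 0.) * 1. ## 1 where mass is nonzero, else 0
Z = Z * M + (1. - M) ## dead components divide by 1 instead of 0
means = jnp.matmul(r.T, X) / Z.T ## dead component's row is all zeros, not NaN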

ngclearn/utils/density/exponentialMixture.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
+from jax import numpy as jnp, random, jit
+from functools import partial
+import time
+
+from ngclearn.utils.density.mixture import Mixture
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+
+@jit
+def _log_exponential_pdf(X, rate):
+    """
+    Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given
+    rate vector `rate`.
+
+    Args:
+        X: a design matrix (dataset) to compute the log likelihood of
+
+        rate: a parameter rate vector
+
+    Returns:
+        (column) vector of per-datapoint log likelihoods of the design matrix X
+    """
+    ## pdf(x; r) = r * exp(-r * x), where r is the "rate"
+    ## log(r * exp(-r x)) = log(r) + log(exp(-r x)) = log(r) - r x
+    vec_ll = -(X * rate) + jnp.log(rate) ## per-dimension log exponential density
+    log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## per-datapoint LL (factorized across dimensions)
+    return log_ll
+
+@jit
+def _calc_exponential_pdf_vals(X, rate):
+    log_ll = _log_exponential_pdf(X, rate) ## get log-likelihood
+    ll = jnp.exp(log_ll) ## likelihood
+    return log_ll, ll
+
+@jit
+def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
+    ## calc new rates, responsibilities, and priors given current stats
+    N = X.shape[0] ## get number of samples
+    ## calc responsibilities
+    r = (pi * weights)
+    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
+    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
+    ## calc weighted means (weighted by responsibilities), then invert to get MLE rates
+    Z = jnp.sum(r, axis=0, keepdims=True) ## calc partition function
+    M = (Z > 0.) * 1.
+    Z = Z * M + (1. - M) ## we mask out any zero partition function values
+    means = jnp.matmul(r.T, X) / Z.T
+    means = jnp.maximum(means, 1e-8) ## guard so dead components still yield finite rates
+    rates = 1. / means ## exponential MLE: the rate is the reciprocal of the weighted mean
+    return rates, _pi, r
+
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+    log_pi = jnp.log(pi) ## calc log(prior)
+    lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+    return lats
+
+@partial(jit, static_argnums=[1])
+def _sample_component(dkey, n_samples, rate): ## samples a component (of mixture)
+    ## if eps ~ Exponential(1), then eps / r ~ Exponential(r)
+    x_s = random.exponential(dkey, shape=(n_samples, rate.shape[1])) / rate ## draw exponential samples
+    return x_s
+
+########################################################################################################################
+
+class ExponentialMixture(Mixture): ## Exponential mixture model (mixture-of-exponentials)
+    """
+    Implements an exponential mixture model (EMM) -- or mixture of exponentials (MoExp).
+    Adaptation of parameters is conducted via the Expectation-Maximization (EM)
+    learning algorithm. Note that this exponential mixture assumes that each component
+    is a factorizable multivariate exponential distribution. (A Categorical distribution
+    is assumed over the latent variables.)
+
+    Args:
+        K: the number of components/latent variables within this EMM
+
+        max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
+
+        init_kmeans: <Unsupported>
+    """
+
+    def __init__(self, K, max_iter=50, init_kmeans=False, key=None, **kwargs):
+        super().__init__(K, max_iter, **kwargs)
+        self.K = K
+        self.max_iter = int(max_iter)
+        self.init_kmeans = init_kmeans ## Unsupported currently
+        self.rate = [] ## component rate parameters
+        self.pi = None ## prior weight parameters
+        #self.z_weights = None # variables for parameterizing weights for SGD
+        self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+    def init(self, X):
+        """
+        Initializes this EMM in accordance with a supplied design matrix.
+
+        Args:
+            X: the design matrix to initialize this EMM to
+        """
+        dim = X.shape[1]
+        self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+        self.rate = []
+        for j in range(self.K):
+            self.key, *skey = random.split(self.key, 3)
+            eps = random.uniform(skey[0], minval=0., maxval=0.5, shape=(1, dim)) ## jitter initial rate params
+            self.rate.append(eps)
+
+    def calc_log_likelihood(self, X):
+        """
+        Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under the current
+        parameters of this exponential mixture.
+
+        Args:
+            X: the design matrix to estimate log likelihood values over under this EMM
+
+        Returns:
+            (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
+        """
+        ll = 0.
+        for j in range(self.K):
+            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
+            ll = ll_j * self.pi[0, j] + ll ## weight each component likelihood by its prior
+        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
+        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
+        return log_ll, complete_ll
+
+    def _E_step(self, X): ## Expectation (E) step, co-routine
+        weights = []
+        for j in range(self.K):
+            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
+            weights.append( ll_j )
+        weights = jnp.concat(weights, axis=1)
+        return weights ## data-dependent weights (intermediate responsibilities)
+
+    def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+        rates, pi, r = _calc_priors_and_rates(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
+        ## store new per-component rate parameters
+        for j in range(self.K):
+            rate_j = rates[j:j + 1, :]
+            self.rate[j] = rate_j ## store new rate(j) parameter
+        return rates, r
+
+    def fit(self, X, tol=1e-3, verbose=False):
+        """
+        Runs the full fitting process of this EMM.
+
+        Args:
+            X: the dataset to fit this EMM to
+
+            tol: the tolerance value for detecting convergence (via difference-of-rates); will engage in
+                early stopping if tol >= 0. (Default: 1e-3)
+
+            verbose: if True, this function will print out per-iteration measurements to I/O
+        """
+        rates_prev = jnp.concat(self.rate, axis=0)
+        for i in range(self.max_iter):
+            self.update(X) ## carry out one E-step followed by an M-step
+            rates = jnp.concat(self.rate, axis=0)
+            dor = jnp.linalg.norm(rates - rates_prev) ## norm of difference-of-rates
+            if verbose:
+                print(f"{i}: Rate-diff = {dor}")
+            if tol >= 0. and dor < tol:
+                print(f"Converged after {i + 1} iterations.")
+                break
+            rates_prev = rates
+
+    def update(self, X):
+        """
+        Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized).
+
+        Args:
+            X: the dataset / design matrix to fit this EMM to
+        """
+        r_w = self._E_step(X) ## carry out E-step
+        rates, respon = self._M_step(X, r_w) ## carry out M-step
+
+    def sample(self, n_samples, mode_j=-1):
+        """
+        Draws samples from the current underlying EMM model.
+
+        Args:
+            n_samples: the number of samples to draw from this EMM
+
+            mode_j: if >= 0, will only draw samples from a specific component of this EMM,
+                ignoring the Categorical prior over latent variables/components (Default = -1)
+
+        Returns:
+            Design matrix of samples drawn under the distribution defined by this EMM
+        """
+        ## sample prior
+        self.key, *skey = random.split(self.key, 3)
+        if mode_j >= 0: ## sample from a particular mode / component
+            rate_j = self.rate[mode_j]
+            Xs = _sample_component(skey[0], n_samples=n_samples, rate=rate_j)
+        else: ## sample from full mixture distribution
+            ## sample components/latents
+            lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+            ## then sample chosen component exponential
+            Xs = []
+            for j in range(self.K):
+                freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+                self.key, *skey = random.split(self.key, 3)
+                x_s = _sample_component(skey[0], n_samples=freq_j, rate=self.rate[j])
+                Xs.append(x_s)
+            Xs = jnp.concat(Xs, axis=0)
+        return Xs
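End to end, the new estimator follows the same init/fit/sample flow as the other mixtures. A minimal usage sketch against the API added above (the synthetic dataset, key, and hyperparameters are illustrative assumptions, not part of the commit):

from jax import numpy as jnp, random
from ngclearn.utils.density import ExponentialMixture

key = random.PRNGKey(1234)
key, subkey = random.split(key)
X = random.exponential(subkey, shape=(500, 8)) / 2.0 ## synthetic data with true rate 2.0

emm = ExponentialMixture(K=3, max_iter=50, key=key)
emm.init(X) ## uniform priors, jittered initial rates
emm.fit(X, tol=1e-3, verbose=True) ## EM until convergence or max_iter
log_ll, complete_ll = emm.calc_log_likelihood(X)
Xs = emm.sample(n_samples=100) ## draw novel samples from the fitted mixture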

ngclearn/utils/density/gaussianMixture.py

Lines changed: 4 additions & 1 deletion
@@ -62,7 +62,10 @@ def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
     r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
     _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
     ## calc weighted means (weighted by responsibilities)
-    means = jnp.matmul(r.T, X) / jnp.sum(r, axis=0, keepdims=True).T
+    Z = jnp.sum(r, axis=0, keepdims=True) ## partition function
+    M = (Z > 0.) * 1.
+    Z = Z * M + (1. - M) ## removes div-by-0 cases
+    means = jnp.matmul(r.T, X) / Z.T
     return means, _pi, r
 
 @partial(jit, static_argnums=[1])
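The same zero-mass guard can also be written with jnp.where, which some readers may find more direct; an equivalent alternative formulation (an editorial sketch, not what the commit uses):

from jax import numpy as jnp

def _masked_column_sums(r):
    Z = jnp.sum(r, axis=0, keepdims=True) ## responsibility mass per component
    return jnp.where(Z > 0., Z, 1.) ## dead components divide by 1 instead of 0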
