from jax import numpy as jnp, random, jit, scipy
from functools import partial
import time, sys
-import numpy as np

from ngclearn.utils.density.mixture import Mixture

########################################################################################################################
## internal routines for mixture model
########################################################################################################################
-
@jit
-def _log_exponential_pdf(X, rate):
+def _log_exponential_pdf(X, lmbda):
    """
-    Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given parameter
+    Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given parameter
    rate vector `lmbda`.

    Args:
        X: a design matrix (dataset) to compute the log likelihood of

-        rate: a parameter rate vector
+        lmbda: a parameter rate vector

    Returns:
        the per-sample log likelihood values of this design matrix X
    """
-    #D = X.shape[1] * 1. ## get dimensionality
-    ## pdf(x; r) = r * np.exp(-r * x), where r is "rate"
-    ## log(r exp(-r x)) = log(r) + log(exp(-r x)) = log(r) - r x
-    vec_ll = -(X * rate) + jnp.log(rate) ## log exponential
-    log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
-    return log_ll
+    log_pdf = -jnp.matmul(X, lmbda.T) + jnp.sum(jnp.log(lmbda.T), axis=0)
+    return log_pdf

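## A minimal illustration (a sketch, not part of this module) of what `_log_exponential_pdf` computes: for a
## rate vector lmbda_k, log p(x | lmbda_k) = sum_d [ log(lmbda_k[d]) - lmbda_k[d] * x[d] ], and the matmul form
## above evaluates this for all K rate vectors at once, yielding an N x K matrix. The toy arrays below
## (X_toy, lmbda_toy) are hypothetical:
##
##   X_toy = jnp.asarray([[0.5, 2.0], [1.0, 1.0]])        ## N=2 samples, D=2 dimensions
##   lmbda_toy = jnp.asarray([[1.0, 0.5], [2.0, 2.0]])    ## K=2 rate vectors (K x D)
##   naive_k0 = jnp.sum(jnp.log(lmbda_toy[0]) - lmbda_toy[0] * X_toy, axis=1)
##   assert jnp.allclose(_log_exponential_pdf(X_toy, lmbda_toy)[:, 0], naive_k0)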
@jit
-def _calc_exponential_pdf_vals(X, p):
-    log_ll = _log_exponential_pdf(X, p) ## get log-likelihood
-    ll = jnp.exp(log_ll) ## likelihood
-    return log_ll, ll
-
-#@jit
-def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
-    ## calc new rates, responsibilities, and priors given current stats
-    N = X.shape[0] ## get number of samples
-    ## calc responsibilities
-    r = (pi * weights)
-    r = r / jnp.sum(r, axis=1, keepdims=True) ## responsibilities
-    _pi = jnp.sum(r, axis=0, keepdims=True) / N ## calc new priors
-    ## calc weighted rates (weighted by responsibilities)
-
-    Znum = jnp.sum(r, axis=0, keepdims=True)
-    #print(Znum.shape)
-    Zden = jnp.matmul(r.T, X)
-    rates = Znum.T / Zden
-    #print(Zden.shape)
-    #exit()
-    """
-    Z = jnp.sum(r, axis=0, keepdims=True) ## calc partition function
-    Ndata = jnp.matmul(r.T, X)
-    M = (Ndata > 0.) * 1.
-    Ndata = Ndata * M + (1. - M) ## we mask out division-by-0 cases
-    rates = Z.T / Ndata
-    """
-    return rates, _pi, r
+def _calc_exponential_mixture_stats(X, lmbda, pi):
+    log_exp_pdf = _log_exponential_pdf(X, lmbda)
+    log_likeli = log_exp_pdf + jnp.log(pi) ## raw (joint) log-likelihood, N x K
+    likeli = jnp.exp(log_likeli) ## raw likelihood
+    gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+    weighted_log_likeli = jnp.sum(log_likeli * gamma, axis=1, keepdims=True) ## per-sample responsibility-weighted log-likelihood
+    complete_loglikeli = jnp.sum(weighted_log_likeli) ## expected complete-data log-likelihood over design matrix X (a lower bound on log p(X))
+    return log_likeli, likeli, gamma, weighted_log_likeli, complete_loglikeli
+
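## A sketch (not part of this module) of the E-step quantities computed above, written out explicitly:
##   gamma[n, k] = pi[k] * p(x_n | lmbda_k) / sum_j ( pi[j] * p(x_n | lmbda_j) )     ## responsibilities
##   complete_loglikeli = sum_n sum_k gamma[n, k] * log( pi[k] * p(x_n | lmbda_k) )  ## EM objective / lower bound
## (The exact data log-likelihood would instead be sum_n logsumexp_k[ log pi[k] + log p(x_n | lmbda_k) ].)
## A numerically safer way to form gamma from log_likeli, assuming `import jax` is available, would be:
##
##   gamma = jax.nn.softmax(log_likeli, axis=1)   ## mathematically equivalent to the exp-then-normalize above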
+@jit
+def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
+    ## compute updates to pi params
+    Zk = jnp.sum(weights, axis=0, keepdims=True) ## summed weights/responsibilities; 1 x K
+    Z = jnp.sum(Zk) ## partition function
+    pi = Zk / Z
+    ## compute updates to lmbda params
+    Z = jnp.matmul(weights.T, X)
+    lmbda = Zk.T / Z
+    return pi, lmbda

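## A sketch (not part of this module) of the closed-form M-step updates that `_calc_priors_and_rates` applies,
## given responsibilities gamma (N x K) and data X (N x D):
##   pi[k]       = ( sum_n gamma[n, k] ) / N
##   lmbda[k, d] = ( sum_n gamma[n, k] ) / ( sum_n gamma[n, k] * X[n, d] )
## i.e., each component's rate is the inverse of its responsibility-weighted sample mean. In raw JAX:
##
##   Nk = jnp.sum(gamma, axis=0, keepdims=True)    ## 1 x K effective counts
##   pi_new = Nk / jnp.sum(Nk)                     ## rows of gamma sum to 1, so jnp.sum(Nk) equals N
##   lmbda_new = Nk.T / jnp.matmul(gamma.T, X)     ## K x D rates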
@partial(jit, static_argnums=[1])
def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
@@ -70,18 +54,17 @@ def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting param
@partial(jit, static_argnums=[1])
def _sample_component(dkey, n_samples, rate): ## samples a component (of mixture)
    ## if u ~ Exp(1), then u / rate ~ Exp(rate), i.e., scale standard exponentials by the inverse rate
-    eps = jax.random.exponential(dkey, shape=(n_samples, mu.shape[1])) * rate ## draw exponential samples
+    x_s = random.exponential(dkey, shape=(n_samples, rate.shape[1])) / rate ## draw standard exponential samples, scale by 1/rate
    return x_s

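## A note (a sketch, not part of this module) on the exponential sampler's scaling: random.exponential draws
## from the standard (rate 1) exponential, so samples for a component with rate `rate` are obtained by dividing
## by `rate` (yielding mean 1/rate), which matches the rate parameterization used by `_log_exponential_pdf`.
## A quick check with hypothetical toy values:
##
##   key_toy = random.PRNGKey(0)
##   rate_toy = jnp.asarray([[2.0, 4.0]])
##   draws = _sample_component(key_toy, n_samples=5000, rate=rate_toy)
##   ## jnp.mean(draws, axis=0) should be close to 1.0 / rate_toy, i.e., roughly [0.5, 0.25]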
########################################################################################################################

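## A brief usage sketch (not part of this module) for the ExponentialMixture class defined below; constructor
## arguments are elided since they are not shown in this diff, and `X_toy` is a hypothetical non-negative
## design matrix:
##
##   emm = ExponentialMixture(...)            ## e.g., configured with K components, max_iter, and a PRNG key
##   emm.init(X_toy)                          ## data-dependent initialization of pi and rate (lmbda) parameters
##   emm.fit(X_toy, tol=1e-3, verbose=True)   ## run EM until the rates stop changing (or max_iter is reached)
##   Xs = emm.sample(n_samples=100)           ## draw new samples from the fitted mixture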
class ExponentialMixture(Mixture): ## Exponential mixture model (mixture-of-exponentials)
    """
-    Implements a exponential mixture model (EMM) -- or mixture of exponentials (MoExp).
-    Adaptation of parameters is conducted via the Expectation-Maximization (EM)
-    learning algorithm. Note that this exponential mixture assumes that each component
-    is a factorizable mutlivariate exponential distribution. (A Categorical distribution
-    is assumed over the latent variables).
+    Implements an exponential mixture model (EMM) -- or mixture of exponentials (MoExp). Adaptation of parameters is
+    conducted via the Expectation-Maximization (EM) learning algorithm. Note that this exponential mixture assumes that
+    each component is a factorizable multivariate exponential distribution. (A Categorical distribution is assumed over
+    the latent variables).

    Args:
        K: the number of components/latent variables within this EMM
@@ -110,15 +93,20 @@ def init(self, X):

        """
        dim = X.shape[1]
-        self.key, *skey = random.split(self.key, 3)
-        self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
-        ptrs = random.permutation(skey[0], X.shape[0])
+        self.key, *skey = random.split(self.key, 4)
+        ## Compute jittered initial pi param values
+        #self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+        pi = jnp.ones((1, self.K))
+        eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, self.K))
+        pi = pi * eps
+        self.pi = pi / jnp.sum(pi)
+
+        ## Compute jittered initial rate (lmbda) param values
+        lmbda_h = 1.0 / jnp.mean(X, axis=0, keepdims=True)
+        lmbda = random.uniform(skey[1], minval=0.99, maxval=1.01, shape=(self.K, dim)) * lmbda_h
        self.rate = []
-        for j in range(self.K):
-            ptr = ptrs[j]
-            self.key, *skey = random.split(self.key, 3)
-            eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, dim)) ## jitter initial rate params
-            self.rate.append(eps)
+        for j in range(self.K): ## set rates/lmbdas
+            self.rate.append(lmbda[j:j + 1, :])

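    ## A short worked example (a sketch, not part of this module) of the initialization above, assuming a
    ## hypothetical 2-D dataset whose per-dimension means are [2.0, 0.5]: the pooled MLE rates are
    ## 1 / mean(X) = [0.5, 2.0], so every component's initial rate vector is drawn from roughly
    ## [0.495, 0.505] x [1.98, 2.02] (the MLE rates scaled by a uniform jitter in [0.99, 1.01]).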
    def calc_log_likelihood(self, X):
        """
@@ -131,31 +119,26 @@ def calc_log_likelihood(self, X):
        Returns:
            (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
        """
-        ll = 0.
-        for j in range(self.K):
-            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
-            ll = ll_j + ll
-        log_ll = jnp.log(ll) ## vector of individual log p(x_n) values
-        complete_ll = jnp.sum(log_ll) ## complete log-likelihood for design matrix X, i.e., log p(X)
-        return log_ll, complete_ll
+        pi = self.pi ## get prior weight values
+        lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+        ## compute relevant log-likelihoods/likelihoods
+        log_ll, ll, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+        return weighted_loglikeli, complete_likeli

    def _E_step(self, X): ## Expectation (E) step, co-routine
-        weights = []
-        for j in range(self.K):
-            log_ll_j, ll_j = _calc_exponential_pdf_vals(X, self.rate[j])
-            weights.append(ll_j)
-        weights = jnp.concat(weights, axis=1)
-        return weights ## data-dependent weights (intermediate responsibilities)
+        pi = self.pi ## get prior weight values
+        lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+        _, _, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+        ## Note: responsibility weights gamma have shape => N x K
+        return gamma, weighted_loglikeli, complete_likeli

    def _M_step(self, X, weights): ## Maximization (M) step, co-routine
-        rates, pi, r = _calc_priors_and_rates(X, weights, self.pi)
-        self.pi = pi ## store new prior parameters
-        # calc weighted covariances
-        for j in range(self.K):
-            #r_j = r[:, j:j + 1]
-            rate_j = rates[j:j + 1, :]
-            self.rate[j] = rate_j ## store new rate(j) parameter
-        return rates, r
+        ## compute updates to pi and lmbda params
+        pi, lmbda = _calc_priors_and_rates(X, weights, self.pi)
+        self.pi = pi ## store new prior parameters
+        for j in range(self.K): ## store new rate/lmbda parameters
+            self.rate[j] = lmbda[j:j + 1, :]
+        return pi, lmbda

    def fit(self, X, tol=1e-3, verbose=False):
        """
@@ -171,11 +154,11 @@ def fit(self, X, tol=1e-3, verbose=False):
        """
        rates_prev = jnp.concat(self.rate, axis=0)
        for i in range(self.max_iter):
-            self.update(X) ## carry out one E-step followed by an M-step
-            rates = jnp.concat(self.rate, axis=0)
+            gamma, pi, rates, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+            # rates = jnp.concat(self.rate, axis=0)
            dor = jnp.linalg.norm(rates - rates_prev) ## norm of difference-of-rates
            if verbose:
-                print(f"{i}: Rate-diff = {dor}")
+                print(f"{i}: Rate-diff = {dor}  log(p(X)) = {complete_loglikeli} nats")
            #print(jnp.linalg.norm(rates - rates_prev))
            if tol >= 0. and dor < tol:
                print(f"Converged after {i + 1} iterations.")
@@ -188,9 +171,13 @@ def update(self, X):

        Args:
            X: the dataset / design matrix to fit this EMM to
+
+        Returns:
+            responsibilities (gamma), priors (pi), rates (lambda), EMM log-likelihood
        """
-        r_w = self._E_step(X) ## carry out E-step
-        rates, respon = self._M_step(X, r_w) ## carry out M-step
+        gamma, _, complete_log_likeli = self._E_step(X) ## carry out E-step
+        pi, rates = self._M_step(X, gamma) ## carry out M-step
+        return gamma, pi, rates, complete_log_likeli

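    ## A sketch (not part of this module) of driving EM manually via `update`, e.g., to log the per-iteration
    ## objective; `emm`, `X_toy`, and `n_iters` are hypothetical:
    ##
    ##   for t in range(n_iters):
    ##       gamma, pi, rates, Q = emm.update(X_toy)   ## one E-step followed by one M-step
    ##       print(t, Q)                               ## Q tracks (but is not identical to) log p(X)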
    def sample(self, n_samples, mode_j=-1):
        """
@@ -205,15 +192,14 @@ def sample(self, n_samples, mode_j=-1):
        Returns:
            Design matrix of samples drawn under the distribution defined by this EMM
        """
-        ## sample prior
        self.key, *skey = random.split(self.key, 3)
-        if mode_j >= 0: ## sample from a particular mode / component
-            rate_j = self.rate[mode_j]
+        if mode_j >= 0: ## sample from a particular mode
+            rate_j = self.rate[mode_j] ## directly select a specific component
            Xs = _sample_component(skey[0], n_samples=n_samples, rate=rate_j)
        else: ## sample from full mixture distribution
-            ## sample components/latents
+            ## sample (prior) components/latents
            lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
-            ## then sample chosen component exponential
+            ## then sample chosen component exponential(s)
            Xs = []
            for j in range(self.K):
                freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode