@@ -2,9 +2,12 @@
 from numpy.testing import assert_allclose
 
 from .helpers import SeededTest
-from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample
+from pymc3 import Dirichlet, Gamma, Normal, Lognormal, Poisson, Exponential, \
+    Mixture, NormalMixture, MvNormal, sample, Metropolis, Model
+import scipy.stats as st
+from scipy.special import logsumexp
 from pymc3.theanof import floatX
-
+import theano
 
 # Generate data
 def generate_normal_mixture_data(w, mu, sd, size=1000):
@@ -104,3 +107,105 @@ def test_mixture_list_of_poissons(self):
         assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                         np.sort(self.pois_mu),
                         rtol=0.1, atol=0.1)
+
+    def test_mixture_of_mvn(self):
+        mu1 = np.asarray([0., 1.])
+        cov1 = np.diag([1.5, 2.5])
+        mu2 = np.asarray([1., 0.])
+        cov2 = np.diag([2.5, 3.5])
+        obs = np.asarray([[.5, .5], mu1, mu2])
+        with Model() as model:
+            w = Dirichlet('w', floatX(np.ones(2)), transform=None)
+            mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
+            mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
+            y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
+                        observed=obs)
+
+        # check logp of each component
+        complogp_st = np.vstack((st.multivariate_normal.logpdf(obs, mu1, cov1),
+                                 st.multivariate_normal.logpdf(obs, mu2, cov2))
+                                ).T
+        complogp = y.distribution._comp_logp(theano.shared(obs)).eval()
+        assert_allclose(complogp, complogp_st)
+
+        # check logp of mixture
+        testpoint = model.test_point
+        mixlogp_st = logsumexp(np.log(testpoint['w']) + complogp_st,
+                               axis=-1, keepdims=True)
+        assert_allclose(y.logp_elemwise(testpoint), mixlogp_st)
+
+        # check logp of model
+        priorlogp = st.dirichlet.logpdf(x=testpoint['w'],
+                                        alpha=np.ones(2))
+        assert_allclose(model.logp(testpoint),
+                        mixlogp_st.sum() + priorlogp)
+    def test_mixture_of_mixture(self):
+        nbr = 4
+        with Model() as model:
+            # component distributions for the two inner mixtures
+            g_comp = Normal.dist(
+                mu=Exponential('mu_g', lam=1.0, shape=nbr, transform=None),
+                sd=1,
+                shape=nbr)
+            l_comp = Lognormal.dist(
+                mu=Exponential('mu_l', lam=1.0, shape=nbr, transform=None),
+                sd=1,
+                shape=nbr)
+            # weight vectors for the two inner mixtures
+            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            # the two inner mixtures
+            g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
+            l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
+            # mixture of the two mixtures
+            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
+            mix = Mixture('mix', w=mix_w,
+                          comp_dists=[g_mix, l_mix],
+                          observed=np.exp(self.norm_x))
+
+        test_point = model.test_point
+
+        def mixmixlogp(value, point):
+            priorlogp = st.dirichlet.logpdf(x=point['g_w'],
+                                            alpha=np.ones(nbr)*0.0000001) + \
+                st.expon.logpdf(x=point['mu_g']).sum() + \
+                st.dirichlet.logpdf(x=point['l_w'],
+                                    alpha=np.ones(nbr)*0.0000001) + \
+                st.expon.logpdf(x=point['mu_l']).sum() + \
+                st.dirichlet.logpdf(x=point['mix_w'],
+                                    alpha=np.ones(2))
+            complogp1 = st.norm.logpdf(x=value, loc=point['mu_g'])
+            mixlogp1 = logsumexp(np.log(point['g_w']) + complogp1,
+                                 axis=-1, keepdims=True)
+            complogp2 = st.lognorm.logpdf(value, 1., 0., np.exp(point['mu_l']))
+            mixlogp2 = logsumexp(np.log(point['l_w']) + complogp2,
+                                 axis=-1, keepdims=True)
+            complogp_mix = np.concatenate((mixlogp1, mixlogp2), axis=1)
+            mixmixlogpg = logsumexp(np.log(point['mix_w']) + complogp_mix,
+                                    axis=-1, keepdims=True)
+            return priorlogp, mixmixlogpg
+
+        value = np.exp(self.norm_x)[:, None]
+        priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
+
+        # check logp of mixture
+        assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
+
+        # check model logp
+        assert_allclose(priorlogp + mixmixlogpg.sum(),
+                        model.logp(test_point))
+
+        # change the test point and check the logps again
+        test_point['g_w'] = np.asarray([.1, .1, .2, .6])
+        test_point['mu_g'] = np.exp(np.random.randn(nbr))
+        priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
+        assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
+        assert_allclose(priorlogp + mixmixlogpg.sum(),
+                        model.logp(test_point))
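
Note on the identity both new tests exercise: a mixture's elementwise log density is
log p(x) = logsumexp_k(log w_k + log p_k(x)) over the component axis, and the tests
rebuild exactly that quantity with scipy before comparing it against Mixture's
logp_elemwise. Below is a minimal self-contained sketch of the same check using only
numpy/scipy; it is not part of the diff and all names in it are illustrative.

    import numpy as np
    import scipy.stats as st
    from scipy.special import logsumexp

    # two-component normal mixture: weights, means, standard deviations
    w = np.array([0.3, 0.7])
    mus = np.array([-1., 2.])
    sds = np.array([1., 0.5])
    x = np.linspace(-3., 4., 11)[:, None]  # column of evaluation points

    # per-component log densities, shape (n_points, n_components)
    comp_logp = st.norm.logpdf(x, loc=mus, scale=sds)
    # mixture log density via the logsumexp identity (numerically stable)
    mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)
    # the same quantity computed naively in the probability domain
    naive = np.log((w * np.exp(comp_logp)).sum(axis=-1))
    np.testing.assert_allclose(mix_logp, naive)

The logsumexp form is what makes the nested case (test_mixture_of_mixture) work:
the inner mixtures' elementwise logps become the comp_logp columns of the outer
mixture, with no round trip through the probability domain.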