Commit 7d327b3

Junpeng Lao and hwassner authored
Mixture of mixtures (#2904)
* Mixture of mixtures

  Following the discussion on Discourse (https://discourse.pymc.io/t/how-can-we-build-a-mixture-of-mixtures/910/), I made some small fixes so that it is easier to create multivariate mixtures and mixtures of mixtures.

* fix test
* add test for multivariate mixture
* fix float32 test

Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: Hubert Wassner <[email protected]>
Parent: 6dd0e81
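The pattern this commit enables, distilled from the `test_mixture_of_mixture` test added below: build the inner mixtures with `Mixture.dist(...)` (unnamed, so they can be used as component distributions) and pass them as `comp_dists` to an outer `Mixture`. A minimal sketch; the data, shapes, and priors here are illustrative, not part of the commit:

```python
import numpy as np
import pymc3 as pm

data = np.random.lognormal(mean=0.5, sigma=1.0, size=1000)  # illustrative data

with pm.Model():
    # inner mixtures: built with .dist() so they are unnamed and usable as components
    g_w = pm.Dirichlet('g_w', a=np.ones(4))
    g_comp = pm.Normal.dist(mu=pm.Exponential('mu_g', lam=1.0, shape=4),
                            sd=1, shape=4)
    g_mix = pm.Mixture.dist(w=g_w, comp_dists=g_comp)

    l_w = pm.Dirichlet('l_w', a=np.ones(4))
    l_comp = pm.Lognormal.dist(mu=pm.Exponential('mu_l', lam=1.0, shape=4),
                               sd=1, shape=4)
    l_mix = pm.Mixture.dist(w=l_w, comp_dists=l_comp)

    # outer mixture whose two components are themselves mixtures
    mix_w = pm.Dirichlet('mix_w', a=np.ones(2))
    mix = pm.Mixture('mix', w=mix_w, comp_dists=[g_mix, l_mix], observed=data)
```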

File tree: 4 files changed, +123 −12 lines changed


RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
 - Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`.
 - Add `offset` kwarg to `.glm`.
 - Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces.
+- Add test and support for creating multivariate mixtures and mixtures of mixtures.
 
 ### Fixes
 

pymc3/distributions/mixture.py

Lines changed: 14 additions & 9 deletions

@@ -67,6 +67,7 @@ class Mixture(Distribution):
 
         like = pm.Mixture('like', w=w, comp_dists = [pois1, pois2], observed=data)
     """
+
     def __init__(self, w, comp_dists, *args, **kwargs):
         shape = kwargs.pop('shape', ())
 
@@ -95,7 +96,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
 
             if 'mode' not in defaults:
                 defaults.append('mode')
-        except AttributeError:
+        except (AttributeError, ValueError):
             pass
 
         super(Mixture, self).__init__(shape, dtype, defaults=defaults,
@@ -109,22 +110,25 @@ def _comp_logp(self, value):
 
             return comp_dists.logp(value_)
         except AttributeError:
-            return tt.stack([comp_dist.logp(value) for comp_dist in comp_dists],
-                            axis=1)
+            return tt.squeeze(tt.stack([comp_dist.logp(value)
+                                        for comp_dist in comp_dists],
+                                       axis=1))
 
     def _comp_means(self):
         try:
             return tt.as_tensor_variable(self.comp_dists.mean)
         except AttributeError:
-            return tt.stack([comp_dist.mean for comp_dist in self.comp_dists],
-                            axis=1)
+            return tt.squeeze(tt.stack([comp_dist.mean
+                                        for comp_dist in self.comp_dists],
+                                       axis=1))
 
     def _comp_modes(self):
         try:
             return tt.as_tensor_variable(self.comp_dists.mode)
         except AttributeError:
-            return tt.stack([comp_dist.mode for comp_dist in self.comp_dists],
-                            axis=1)
+            return tt.squeeze(tt.stack([comp_dist.mode
+                                        for comp_dist in self.comp_dists],
+                                       axis=1))
 
     def _comp_samples(self, point=None, size=None, repeat=None):
         try:
@@ -196,15 +200,16 @@ class NormalMixture(Mixture):
 
     Note: You only have to pass in sd or tau, but not both.
     """
+
     def __init__(self, w, mu, *args, **kwargs):
         _, sd = get_tau_sd(tau=kwargs.pop('tau', None),
                            sd=kwargs.pop('sd', None))
-
+
         distshape = np.broadcast(mu, sd).shape
         self.mu = mu = tt.as_tensor_variable(mu)
         self.sd = sd = tt.as_tensor_variable(sd)
 
-        if not distshape:
+        if not distshape:
             distshape = np.broadcast(mu.tag.test_value, sd.tag.test_value).shape
 
         super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd, shape=distshape),
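Two substantive changes in this file. Catching `ValueError` alongside `AttributeError` lets `__init__` keep going when computing a default `mode` fails for multivariate components (my reading of the fix; the commit message doesn't spell it out). The `tt.squeeze` wrappers drop stray length-one axes that can appear when per-component `logp`, `mean`, or `mode` terms are stacked along `axis=1` for multivariate components; without the squeeze, the extra axis makes the broadcasting against the mixture weights go wrong. A rough NumPy analogue under assumed shapes:

```python
import numpy as np

# Assumed shapes for illustration: each component's logp evaluates to
# (n_obs, 1) rather than (n_obs,) once an event dimension has been reduced.
n_obs, n_comp = 5, 2
comp_logps = [np.random.randn(n_obs, 1) for _ in range(n_comp)]

stacked = np.stack(comp_logps, axis=1)   # (5, 2, 1): stray trailing axis
squeezed = np.squeeze(stacked)           # (5, 2): one column per component

w = np.array([0.4, 0.6])
weighted = np.log(w) + squeezed          # broadcasts cleanly to (5, 2)
print(stacked.shape, squeezed.shape, weighted.shape)
```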

pymc3/stats.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 import pymc3 as pm
 from pymc3.theanof import floatX
 
-from scipy.misc import logsumexp
+from scipy.special import logsumexp
 from scipy.stats import dirichlet
 from scipy.optimize import minimize
 from scipy.signal import fftconvolve
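A one-line but necessary fix: `scipy.misc.logsumexp` was deprecated in favor of `scipy.special.logsumexp` and later removed from `scipy.misc`, so the import has to move. `logsumexp` is what keeps the mixture log-density numerically stable, computing log(sum_k w_k * exp(logp_k)) without underflow. A minimal example with made-up numbers:

```python
import numpy as np
from scipy.special import logsumexp

w = np.array([0.3, 0.7])              # mixture weights
comp_logp = np.array([[-1.2, -0.4],   # per-observation, per-component logp
                      [-2.0, -0.1]])

# log(sum_k w_k * exp(logp_k)) per observation, computed stably
mix_logp = logsumexp(np.log(w) + comp_logp, axis=-1)
print(mix_logp)
```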

pymc3/tests/test_mixture.py

Lines changed: 107 additions & 2 deletions

@@ -2,9 +2,12 @@
 from numpy.testing import assert_allclose
 
 from .helpers import SeededTest
-from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample
+from pymc3 import Dirichlet, Gamma, Normal, Lognormal, Poisson, Exponential, \
+    Mixture, NormalMixture, MvNormal, sample, Metropolis, Model
+import scipy.stats as st
+from scipy.special import logsumexp
 from pymc3.theanof import floatX
-
+import theano
 
 # Generate data
 def generate_normal_mixture_data(w, mu, sd, size=1000):
@@ -104,3 +107,105 @@ def test_mixture_list_of_poissons(self):
         assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                         np.sort(self.pois_mu),
                         rtol=0.1, atol=0.1)
+
+    def test_mixture_of_mvn(self):
+        mu1 = np.asarray([0., 1.])
+        cov1 = np.diag([1.5, 2.5])
+        mu2 = np.asarray([1., 0.])
+        cov2 = np.diag([2.5, 3.5])
+        obs = np.asarray([[.5, .5], mu1, mu2])
+        with Model() as model:
+            w = Dirichlet('w', floatX(np.ones(2)), transform=None)
+            mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
+            mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
+            y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
+                        observed=obs)
+
+        # check logp of each component
+        complogp_st = np.vstack((st.multivariate_normal.logpdf(obs, mu1, cov1),
+                                 st.multivariate_normal.logpdf(obs, mu2, cov2))
+                                ).T
+        complogp = y.distribution._comp_logp(theano.shared(obs)).eval()
+        assert_allclose(complogp, complogp_st)
+
+        # check logp of mixture
+        testpoint = model.test_point
+        mixlogp_st = logsumexp(np.log(testpoint['w']) + complogp_st,
+                               axis=-1, keepdims=True)
+        assert_allclose(y.logp_elemwise(testpoint),
+                        mixlogp_st)
+
+        # check logp of model
+        priorlogp = st.dirichlet.logpdf(x=testpoint['w'],
+                                        alpha=np.ones(2))
+        assert_allclose(model.logp(testpoint),
+                        mixlogp_st.sum() + priorlogp)
+
+    def test_mixture_of_mixture(self):
+        nbr = 4
+        with Model() as model:
+            # component distributions for the inner mixtures
+            g_comp = Normal.dist(
+                mu=Exponential('mu_g', lam=1.0, shape=nbr, transform=None),
+                sd=1,
+                shape=nbr)
+            l_comp = Lognormal.dist(
+                mu=Exponential('mu_l', lam=1.0, shape=nbr, transform=None),
+                sd=1,
+                shape=nbr)
+            # weight vectors for the inner mixtures
+            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            # inner mixtures
+            g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
+            l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
+            # mixture of mixtures
+            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
+            mix = Mixture('mix', w=mix_w,
+                          comp_dists=[g_mix, l_mix],
+                          observed=np.exp(self.norm_x))
+
+        test_point = model.test_point
+
+        def mixmixlogp(value, point):
+            priorlogp = st.dirichlet.logpdf(x=point['g_w'],
+                                            alpha=np.ones(nbr)*0.0000001) + \
+                        st.expon.logpdf(x=point['mu_g']).sum() + \
+                        st.dirichlet.logpdf(x=point['l_w'],
+                                            alpha=np.ones(nbr)*0.0000001) + \
+                        st.expon.logpdf(x=point['mu_l']).sum() + \
+                        st.dirichlet.logpdf(x=point['mix_w'],
+                                            alpha=np.ones(2))
+            complogp1 = st.norm.logpdf(x=value,
+                                       loc=point['mu_g'])
+            mixlogp1 = logsumexp(np.log(point['g_w']) + complogp1,
+                                 axis=-1, keepdims=True)
+            complogp2 = st.lognorm.logpdf(value, 1., 0., np.exp(point['mu_l']))
+            mixlogp2 = logsumexp(np.log(point['l_w']) + complogp2,
+                                 axis=-1, keepdims=True)
+            complogp_mix = np.concatenate((mixlogp1, mixlogp2), axis=1)
+            mixmixlogpg = logsumexp(np.log(point['mix_w']) + complogp_mix,
+                                    axis=-1, keepdims=True)
+            return priorlogp, mixmixlogpg
+
+        value = np.exp(self.norm_x)[:, None]
+        priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
+
+        # check logp of mixture
+        assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
+
+        # check model logp
+        assert_allclose(priorlogp + mixmixlogpg.sum(),
+                        model.logp(test_point))
+
+        # change the test point and check the logp again
+        test_point['g_w'] = np.asarray([.1, .1, .2, .6])
+        test_point['mu_g'] = np.exp(np.random.randn(nbr))
+        priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
+        assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
+        assert_allclose(priorlogp + mixmixlogpg.sum(),
+                        model.logp(test_point))
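The reference computation in `mixmixlogp` above is two nested applications of the same identity: one `logsumexp` per inner mixture, then another over the stacked inner results. A hypothetical helper (not part of the commit) that isolates the pattern:

```python
import numpy as np
from scipy.special import logsumexp

def nested_mixture_logp(w_outer, inner_weights, inner_comp_logps):
    """Elementwise logp of a mixture whose components are mixtures.

    inner_comp_logps[k]: (n_obs, n_comp_k) per-component logps of inner mixture k
    inner_weights[k]: weight vector of inner mixture k
    """
    inner = [logsumexp(np.log(w) + lp, axis=-1, keepdims=True)
             for w, lp in zip(inner_weights, inner_comp_logps)]
    return logsumexp(np.log(w_outer) + np.concatenate(inner, axis=1),
                     axis=-1, keepdims=True)

# made-up numbers: 4 observations, two inner mixtures with 2 components each
lp = nested_mixture_logp(np.array([0.5, 0.5]),
                         [np.array([0.3, 0.7]), np.array([0.2, 0.8])],
                         [np.random.randn(4, 2), np.random.randn(4, 2)])
print(lp.shape)  # (4, 1): one logp per observation
```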
