Skip to content

Commit 7c0571a

Browse files
Junpeng LaoColCarroll
authored andcommitted
Fix 3044 (#3046)
Mixture multivariate shape issue. close #3044 and add test
1 parent 12694dd commit 7c0571a

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

pymc3/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
9696

9797
if 'mode' not in defaults:
9898
defaults.append('mode')
99-
except (AttributeError, ValueError):
99+
except (AttributeError, ValueError, IndexError):
100100
pass
101101

102102
super(Mixture, self).__init__(shape, dtype, defaults=defaults,

pymc3/tests/test_mixture.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ def test_normal_mixture(self):
7474
np.sort(self.norm_mu),
7575
rtol=0.1, atol=0.1)
7676

77+
def test_normal_mixture_nd(self):
78+
nd, ncomp = 3, 5
79+
80+
with Model() as model0:
81+
mus = Normal('mus', shape=(nd, ncomp))
82+
taus = Gamma('taus', alpha=1, beta=1, shape=(nd, ncomp))
83+
ws = Dirichlet('ws', np.ones(ncomp))
84+
mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd)
85+
86+
with Model() as model1:
87+
mus = Normal('mus', shape=(nd, ncomp))
88+
taus = Gamma('taus', alpha=1, beta=1, shape=(nd, ncomp))
89+
ws = Dirichlet('ws', np.ones(ncomp))
90+
comp_dist = [Normal.dist(mu=mus[:, i], tau=taus[:, i])
91+
for i in range(ncomp)]
92+
mixture1 = Mixture('m', w=ws, comp_dists=comp_dist, shape=nd)
93+
94+
testpoint = model0.test_point
95+
testpoint['mus'] = np.random.randn(nd, ncomp)
96+
assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
97+
assert_allclose(mixture0.logp(testpoint), mixture1.logp(testpoint))
98+
7799
def test_poisson_mixture(self):
78100
with Model() as model:
79101
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))

0 commit comments

Comments
 (0)