@@ -74,6 +74,28 @@ def test_normal_mixture(self):
74
74
np .sort (self .norm_mu ),
75
75
rtol = 0.1 , atol = 0.1 )
76
76
77
+ def test_normal_mixture_nd (self ):
78
+ nd , ncomp = 3 , 5
79
+
80
+ with Model () as model0 :
81
+ mus = Normal ('mus' , shape = (nd , ncomp ))
82
+ taus = Gamma ('taus' , alpha = 1 , beta = 1 , shape = (nd , ncomp ))
83
+ ws = Dirichlet ('ws' , np .ones (ncomp ))
84
+ mixture0 = NormalMixture ('m' , w = ws , mu = mus , tau = taus , shape = nd )
85
+
86
+ with Model () as model1 :
87
+ mus = Normal ('mus' , shape = (nd , ncomp ))
88
+ taus = Gamma ('taus' , alpha = 1 , beta = 1 , shape = (nd , ncomp ))
89
+ ws = Dirichlet ('ws' , np .ones (ncomp ))
90
+ comp_dist = [Normal .dist (mu = mus [:, i ], tau = taus [:, i ])
91
+ for i in range (ncomp )]
92
+ mixture1 = Mixture ('m' , w = ws , comp_dists = comp_dist , shape = nd )
93
+
94
+ testpoint = model0 .test_point
95
+ testpoint ['mus' ] = np .random .randn (nd , ncomp )
96
+ assert_allclose (model0 .logp (testpoint ), model1 .logp (testpoint ))
97
+ assert_allclose (mixture0 .logp (testpoint ), mixture1 .logp (testpoint ))
98
+
77
99
def test_poisson_mixture (self ):
78
100
with Model () as model :
79
101
w = Dirichlet ('w' , floatX (np .ones_like (self .pois_w )))
0 commit comments