@@ -820,10 +820,8 @@ def test_normal_mixture_nd(self, seeded_test, nd, ncomp):
820
820
mus = Normal ("mus" , shape = comp_shape )
821
821
taus = Gamma ("taus" , alpha = 1 , beta = 1 , shape = comp_shape )
822
822
ws = Dirichlet ("ws" , np .ones (ncomp ), shape = (ncomp ,))
823
- mixture0 = NormalMixture ("m" , w = ws , mu = mus , tau = taus , shape = nd , comp_shape = comp_shape )
824
- obs0 = NormalMixture (
825
- "obs" , w = ws , mu = mus , tau = taus , comp_shape = comp_shape , observed = observed
826
- )
823
+ mixture0 = NormalMixture ("m" , w = ws , mu = mus , tau = taus , shape = nd )
824
+ obs0 = NormalMixture ("obs" , w = ws , mu = mus , tau = taus , observed = observed )
827
825
828
826
with Model () as model1 :
829
827
mus = Normal ("mus" , shape = comp_shape )
@@ -867,7 +865,6 @@ def ref_rand(size, w, mu, sigma):
867
865
"mu" : Domain ([[0.05 , 2.5 ], [- 5.0 , 1.0 ]], edges = (None , None )),
868
866
"sigma" : Domain ([[1 , 1 ], [1.5 , 2.0 ]], edges = (None , None )),
869
867
},
870
- extra_args = {"comp_shape" : 2 },
871
868
size = 1000 ,
872
869
ref_rand = ref_rand ,
873
870
)
@@ -878,7 +875,6 @@ def ref_rand(size, w, mu, sigma):
878
875
"mu" : Domain ([[- 5.0 , 1.0 , 2.5 ]], edges = (None , None )),
879
876
"sigma" : Domain ([[1.5 , 2.0 , 3.0 ]], edges = (None , None )),
880
877
},
881
- extra_args = {"comp_shape" : 3 },
882
878
size = 1000 ,
883
879
ref_rand = ref_rand ,
884
880
)
@@ -902,7 +898,6 @@ def test_scalar_components(self):
902
898
w = np .ones (npop ) / npop ,
903
899
mu = mus ,
904
900
sigma = 1e-5 ,
905
- comp_shape = (nd , npop ),
906
901
shape = nd ,
907
902
)
908
903
z = Categorical ("z" , p = np .ones (npop ) / npop , shape = nd )
0 commit comments