Skip to content

Commit 62b341f

Browse files
committed
better test
1 parent 899e2ef commit 62b341f

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ADVI, FullRankADVI, SVGD, NFVI, ASVGD,
1717
fit
1818
)
19-
from pymc3.variational import flows
19+
from pymc3.variational import flows, opvi
2020
from pymc3.variational.opvi import Approximation, Group
2121

2222
from . import models
@@ -79,7 +79,7 @@ def test_tracker_callback():
7979
@pytest.fixture('module')
8080
def three_var_model():
8181
with pm.Model() as model:
82-
pm.Normal('one', shape=(10, 2))
82+
pm.HalfNormal('one', shape=(10, 2))
8383
pm.Normal('two', shape=(10, ))
8484
pm.Normal('three', shape=(10, 1, 2))
8585
return model
@@ -174,7 +174,7 @@ def parametric_grouped_approxes(request):
174174

175175
@pytest.fixture
176176
def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial):
177-
dsize = np.prod(three_var_model.one.dshape[1:])
177+
dsize = np.prod(opvi.get_transformed(three_var_model.one).dshape[1:])
178178
cls, kw = parametric_grouped_approxes
179179
spec = cls.get_param_spec_for(d=dsize, **kw)
180180
params = dict()
@@ -610,11 +610,11 @@ def test_fit_fn_text(method, kwargs, error, another_simple_model):
610610
@pytest.fixture('module')
611611
def aevb_model():
612612
with pm.Model() as model:
613-
pm.Normal('x', shape=(2,))
613+
pm.HalfNormal('x', shape=(2,))
614614
pm.Normal('y', shape=(2,))
615615
x = model.x
616616
y = model.y
617-
mu = theano.shared(x.init_value) * 2
617+
mu = theano.shared(x.init_value)
618618
rho = theano.shared(np.zeros_like(x.init_value))
619619
return {
620620
'model': model,
@@ -632,8 +632,8 @@ def test_aevb(inference_spec, aevb_model):
632632
replace = aevb_model['replace']
633633
with model:
634634
try:
635-
inference = inference_spec(local_rv={x: replace})
636-
approx = inference.fit(3, obj_n_mc=2)
635+
inference = inference_spec(local_rv={x: {'mu': replace['mu']*5, 'rho': replace['rho']}})
636+
approx = inference.fit(3, obj_n_mc=2, more_obj_params=replace.values())
637637
approx.sample(10)
638638
approx.sample_node(
639639
y,

0 commit comments

Comments
 (0)