16
16
ADVI , FullRankADVI , SVGD , NFVI , ASVGD ,
17
17
fit
18
18
)
19
- from pymc3 .variational import flows
19
+ from pymc3 .variational import flows , opvi
20
20
from pymc3 .variational .opvi import Approximation , Group
21
21
22
22
from . import models
@@ -79,7 +79,7 @@ def test_tracker_callback():
79
79
@pytest .fixture ('module' )
80
80
def three_var_model ():
81
81
with pm .Model () as model :
82
- pm .Normal ('one' , shape = (10 , 2 ))
82
+ pm .HalfNormal ('one' , shape = (10 , 2 ))
83
83
pm .Normal ('two' , shape = (10 , ))
84
84
pm .Normal ('three' , shape = (10 , 1 , 2 ))
85
85
return model
@@ -174,7 +174,7 @@ def parametric_grouped_approxes(request):
174
174
175
175
@pytest .fixture
176
176
def three_var_aevb_groups (parametric_grouped_approxes , three_var_model , aevb_initial ):
177
- dsize = np .prod (three_var_model .one .dshape [1 :])
177
+ dsize = np .prod (opvi . get_transformed ( three_var_model .one ) .dshape [1 :])
178
178
cls , kw = parametric_grouped_approxes
179
179
spec = cls .get_param_spec_for (d = dsize , ** kw )
180
180
params = dict ()
@@ -610,11 +610,11 @@ def test_fit_fn_text(method, kwargs, error, another_simple_model):
610
610
@pytest .fixture ('module' )
611
611
def aevb_model ():
612
612
with pm .Model () as model :
613
- pm .Normal ('x' , shape = (2 ,))
613
+ pm .HalfNormal ('x' , shape = (2 ,))
614
614
pm .Normal ('y' , shape = (2 ,))
615
615
x = model .x
616
616
y = model .y
617
- mu = theano .shared (x .init_value ) * 2
617
+ mu = theano .shared (x .init_value )
618
618
rho = theano .shared (np .zeros_like (x .init_value ))
619
619
return {
620
620
'model' : model ,
@@ -632,8 +632,8 @@ def test_aevb(inference_spec, aevb_model):
632
632
replace = aevb_model ['replace' ]
633
633
with model :
634
634
try :
635
- inference = inference_spec (local_rv = {x : replace })
636
- approx = inference .fit (3 , obj_n_mc = 2 )
635
+ inference = inference_spec (local_rv = {x : { 'mu' : replace [ 'mu' ] * 5 , 'rho' : replace [ 'rho' ]} })
636
+ approx = inference .fit (3 , obj_n_mc = 2 , more_obj_params = replace . values () )
637
637
approx .sample (10 )
638
638
approx .sample_node (
639
639
y ,
0 commit comments