Skip to content

Commit da14207

Browse files
committed
fix shape issue
1 parent ab95093 commit da14207

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_sample_simple(three_var_approx):
150150
assert trace[0]['three'].shape == (10, 1, 2)
151151

152152

153-
@pytest.fixture('module')
153+
@pytest.fixture
154154
def aevb_initial():
155155
return theano.shared(np.random.rand(3, 7).astype('float32'))
156156

@@ -196,17 +196,20 @@ def three_var_aevb_approx(three_var_model, three_var_aevb_groups):
196196

197197

198198
def test_sample_aevb(three_var_aevb_approx, aevb_initial):
199+
pm.KLqp(three_var_aevb_approx).fit(1, more_replacements={
200+
aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]
201+
})
199202
aevb_initial.set_value(np.random.rand(7, 7).astype('float32'))
200203
trace = three_var_aevb_approx.sample(500)
201-
assert set(trace.varnames) == {'one', 'two', 'three'}
204+
assert set(trace.varnames) == {'one', 'one_log__', 'two', 'three'}
202205
assert len(trace) == 500
203206
assert trace[0]['one'].shape == (7, 2)
204207
assert trace[0]['two'].shape == (10, )
205208
assert trace[0]['three'].shape == (10, 1, 2)
206209

207210
aevb_initial.set_value(np.random.rand(13, 7).astype('float32'))
208211
trace = three_var_aevb_approx.sample(500)
209-
assert set(trace.varnames) == {'one', 'two', 'three'}
212+
assert set(trace.varnames) == {'one', 'one_log__', 'two', 'three'}
210213
assert len(trace) == 500
211214
assert trace[0]['one'].shape == (13, 2)
212215
assert trace[0]['two'].shape == (10,)
@@ -610,7 +613,7 @@ def test_fit_fn_text(method, kwargs, error, another_simple_model):
610613
@pytest.fixture('module')
611614
def aevb_model():
612615
with pm.Model() as model:
613-
pm.HalfNormal('x', shape=(2,))
616+
pm.HalfNormal('x', shape=(2,), total_size=5)
614617
pm.Normal('y', shape=(2,))
615618
x = model.x
616619
y = model.y
@@ -633,7 +636,7 @@ def test_aevb(inference_spec, aevb_model):
633636
with model:
634637
try:
635638
inference = inference_spec(local_rv={x: {'mu': replace['mu']*5, 'rho': replace['rho']}})
636-
approx = inference.fit(3, obj_n_mc=2, more_obj_params=replace.values())
639+
approx = inference.fit(3, obj_n_mc=2, more_obj_params=list(replace.values()))
637640
approx.sample(10)
638641
approx.sample_node(
639642
y,

pymc3/variational/approximations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def cov(self):
5454
def std(self):
5555
return rho2sd(self.rho)
5656

57+
@change_flags(compute_test_value='off')
5758
def __init_group__(self, group):
5859
super(MeanFieldGroup, self).__init_group__(group)
5960
if not self._check_user_params():
@@ -111,6 +112,7 @@ class FullRankGroup(Group):
111112
short_name = 'full_rank'
112113
alias_names = frozenset(['fr'])
113114

115+
@change_flags(compute_test_value='off')
114116
def __init_group__(self, group):
115117
super(FullRankGroup, self).__init_group__(group)
116118
if not self._check_user_params():
@@ -224,6 +226,7 @@ class EmpiricalGroup(Group):
224226
__param_spec__ = dict(histogram=('s', 'd'))
225227
short_name = 'empirical'
226228

229+
@change_flags(compute_test_value='off')
227230
def __init_group__(self, group):
228231
super(EmpiricalGroup, self).__init_group__(group)
229232
self._check_trace()

pymc3/variational/opvi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,8 +838,8 @@ def __init_group__(self, group):
838838
# so I have to to it by myself
839839
self.ordering = ArrayOrdering([])
840840
self.replacements = dict()
841+
self.group = [get_transformed(var) for var in self.group]
841842
for var in self.group:
842-
var = get_transformed(var)
843843
begin = self.ddim
844844
if self.batched:
845845
if var.ndim < 1:
@@ -1003,7 +1003,11 @@ def symbolic_logq_not_scaled(self):
10031003
@node_property
10041004
def symbolic_logq(self):
10051005
if self.islocal:
1006-
return self.symbolic_logq_not_scaled * self.group[0].scaling
1006+
s = self.group[0].scaling
1007+
s = self.to_flat_input(s)
1008+
s = self.symbolic_single_sample(s)
1009+
s = self.set_size_and_deterministic(s, 1, 1)
1010+
return self.symbolic_logq_not_scaled * s
10071011
else:
10081012
return self.symbolic_logq_not_scaled
10091013

0 commit comments

Comments
 (0)