Skip to content

Commit c2cf5d3

Browse files
committed
found and fixed a bug with aevb
1 parent 85af2f6 commit c2cf5d3

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_tracker_callback():
8080
@pytest.fixture(scope='module')
def three_var_model():
    """Module-scoped fixture: a model with three free random variables.

    The variables deliberately have different ranks -- (10, 2), (10,) and
    (10, 1, 2) -- so tests that use the fixture see variables of several
    shapes.

    Returns
    -------
    pm.Model
        The model holding the variables 'one', 'two' and 'three'.
    """
    # The scope is passed by keyword: positional fixture arguments were
    # deprecated by pytest and later removed.
    with pm.Model() as model:
        # NOTE(review): total_size=100 presumably marks 'one' as a minibatch
        # of a larger dataset so its log-density is rescaled -- confirm.
        pm.HalfNormal('one', shape=(10, 2), total_size=100)
        pm.Normal('two', shape=(10, ))
        pm.Normal('three', shape=(10, 1, 2))
    return model
@@ -228,6 +228,33 @@ def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial):
228228
more_replacements={aevb_initial: inp}).eval({inp: np.random.rand(7, 7).astype('float32')})
229229

230230

231+
def test_vae():
    """Smoke-test one fitting step of a tiny auto-encoding (AEVB) setup.

    A linear "encoder" maps each observed value to the ``mu`` and ``rho`` of
    a local latent variable; a linear "decoder" maps the latent back to the
    mean of the observation model.  The test only checks that ``pm.fit``
    runs for one iteration with the symbolic input replaced by a minibatch.
    """
    minibatch_size = 10
    data = np.random.rand(100).astype('float32')
    x_mini = pm.Minibatch(data, minibatch_size)

    # Symbolic placeholder for a batch; swapped for ``x_mini`` inside fit.
    x_inp = tt.vector()
    x_inp.tag.test_value = data[:minibatch_size]

    # Encoder parameters: one weight per output column (mu, rho) plus bias.
    enc_weight = theano.shared(np.asarray([.1, .1], 'float32'))
    enc_bias = theano.shared(np.asarray(1., dtype='float32'))

    # Decoder parameters: scalar weight and bias.
    dec_weight = theano.shared(np.asarray(1., dtype='float32'))
    dec_bias = theano.shared(np.asarray(1., dtype='float32'))

    # (batch, 1) * (1, 2) + bias -> one (mu, rho) pair per observation.
    encoded = x_inp.dimshuffle(0, 'x') * enc_weight.dimshuffle('x', 0) + enc_bias
    mu = encoded[:, 0]
    rho = encoded[:, 1]

    with pm.Model():
        # Hidden variables
        zs = pm.Normal('zs', mu=0, sd=1, shape=minibatch_size, dtype='float32')
        decoded = zs * dec_weight + dec_bias
        # Observation model
        pm.Normal('xs_', mu=decoded, sd=0.1, observed=x_inp, dtype='float32')

        pm.fit(
            1,
            local_rv={zs: {'mu': mu, 'rho': rho}},
            more_replacements={x_inp: x_mini},
            more_obj_params=[enc_weight, enc_bias, dec_weight, dec_bias],
        )
256+
257+
231258
def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model):
232259
cls, kw = parametric_grouped_approxes
233260
approx = cls([three_var_model.one], model=three_var_model, **kw)

pymc3/variational/opvi.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -884,9 +884,11 @@ def params(self):
884884
else:
885885
return collect_shared_to_list(self.shared_params)
886886

887-
def _new_initial_shape(self, size, dim, more_replacements=None):
    """Build the symbolic shape used for a new initial tensor.

    Parameters
    ----------
    size
        Leading (sample) dimension of the shape.
    dim
        Trailing dimension of the shape.
    more_replacements : dict, optional
        Replacements applied to ``self.bdim`` when the group is batched.
        Defaults to ``None`` (no replacements), so the earlier two-argument
        call form ``_new_initial_shape(size, dim)`` keeps working.

    Returns
    -------
    A stacked shape tensor: ``(size, bdim, dim)`` when ``self.batched`` is
    true, ``(size, dim)`` otherwise.
    """
    if self.batched:
        bdim = tt.as_tensor(self.bdim)
        # The batch dimension may be an expression of variables that the
        # caller is replacing, so it goes through the same replacements.
        bdim = theano.clone(bdim, more_replacements)
        return tt.stack([size, bdim, dim])
    else:
        return tt.stack([size, dim])
892894

@@ -908,7 +910,7 @@ def ndim(self):
908910
def ddim(self):
909911
return self.ordering.size
910912

911-
def _new_initial(self, size, deterministic):
913+
def _new_initial(self, size, deterministic, more_replacements=None):
912914
if size is None:
913915
size = 1
914916
if not isinstance(deterministic, tt.Variable):
@@ -921,7 +923,7 @@ def _new_initial(self, size, deterministic):
921923
dtype = self.symbolic_initial.dtype
922924
dim = tt.as_tensor(dim)
923925
size = tt.as_tensor(size)
924-
shape = self._new_initial_shape(size, dim)
926+
shape = self._new_initial_shape(size, dim, more_replacements)
925927
# apply optimizations if possible
926928
if not isinstance(deterministic, tt.Variable):
927929
if deterministic:
@@ -951,33 +953,38 @@ def symbolic_random2d(self):
951953
@change_flags(compute_test_value='off')
def set_size_and_deterministic(self, node, s, d, more_replacements=None):
    """Clone ``node`` with the symbolic initial replaced for size/determinism.

    Parameters
    ----------
    node
        Theano node (graph output) to clone.
    s
        Sample size passed on to the replacement builder.
    d
        Deterministic flag passed on to the replacement builder.
    more_replacements : dict, optional
        Extra replacements passed on to the replacement builder.

    Returns
    -------
    The cloned node.
    """
    replacements = self.make_size_and_deterministic_replacements(
        s, d, more_replacements
    )
    # strict=False relaxes the type match theano.clone requires between a
    # variable and its replacement.
    replaced = theano.clone(node, replacements, strict=False)
    try_to_set_test_value(node, replaced, s)
    return replaced
957959

958960
def to_flat_input(self, node):
    """Clone ``node`` with ``self.replacements`` applied.

    The replacements swap variables for their flattened views; the clone is
    non-strict, so a replacement's type need not match exactly.
    """
    flat_node = theano.clone(node, self.replacements, strict=False)
    return flat_node
962964

963965
def symbolic_sample_over_posterior(self, node):
    """Evaluate ``node`` once for every row of ``self.symbolic_random``.

    ``node`` is first rewritten onto the flat input.  The random tensor is
    cast to the dtype of ``self.symbolic_initial`` and given its
    broadcastable pattern before being scanned over.

    Returns
    -------
    The stacked scan outputs, one clone of ``node`` per scanned row.
    """
    flat_node = self.to_flat_input(node)
    draws = self.symbolic_random.astype(self.symbolic_initial.dtype)
    draws = tt.patternbroadcast(draws, self.symbolic_initial.broadcastable)

    def substitute(draw):
        # Plug a single draw in place of the group's flat input.
        return theano.clone(flat_node, {self.input: draw})

    outputs, _updates = theano.scan(substitute, draws)
    return outputs
972976

973977
def symbolic_single_sample(self, node):
    """Evaluate ``node`` at the first row of ``self.symbolic_random``.

    ``node`` is first rewritten onto the flat input.  The random tensor is
    cast to the dtype of ``self.symbolic_initial`` and given its
    broadcastable pattern before its first row replaces ``self.input``.
    """
    flat_node = self.to_flat_input(node)
    draws = self.symbolic_random.astype(self.symbolic_initial.dtype)
    draws = tt.patternbroadcast(draws, self.symbolic_initial.broadcastable)
    first_draw = draws[0]
    return theano.clone(flat_node, {self.input: first_draw})
978984

979985
def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
    """Build the replacement for ``self.symbolic_initial``.

    Parameters
    ----------
    s
        Sample size forwarded to ``self._new_initial``.
    d
        Deterministic flag forwarded to ``self._new_initial``.
    more_replacements : dict, optional
        Forwarded to ``self._new_initial`` and, when non-empty, also applied
        to the resulting tensor.

    Returns
    -------
    dict
        Single-entry mapping ``{self.symbolic_initial: new_initial}``.
    """
    new_initial = tt.patternbroadcast(
        self._new_initial(s, d, more_replacements),
        # Keep the broadcastable pattern of the variable being replaced.
        self.symbolic_initial.broadcastable,
    )
    if more_replacements:
        new_initial = theano.clone(new_initial, more_replacements)
    return {self.symbolic_initial: new_initial}
@@ -1072,6 +1079,7 @@ def __init__(self, groups, model=None):
10721079
model = modelcontext(model)
10731080
if not model.free_RVs:
10741081
raise TypeError('Model does not have FreeRVs')
1082+
self.groups = list()
10751083
seen = set()
10761084
rest = None
10771085
for g in groups:
@@ -1085,12 +1093,13 @@ def __init__(self, groups, model=None):
10851093
if set(g.group) & seen:
10861094
raise GroupError('Found duplicates in groups')
10871095
seen.update(g.group)
1096+
self.groups.append(g)
10881097
if set(model.free_RVs) - seen:
10891098
if rest is None:
10901099
raise GroupError('No approximation is specified for the rest variables')
10911100
else:
10921101
rest.__init_group__(list(set(model.free_RVs) - seen))
1093-
self.groups = groups
1102+
self.groups.append(rest)
10941103
self.model = model
10951104

10961105
@property
@@ -1168,17 +1177,17 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
11681177
if more_replacements is None:
11691178
more_replacements = {}
11701179
flat2rand = collections.OrderedDict()
1171-
flat2rand.update(more_replacements)
11721180
for g in self.groups:
11731181
flat2rand.update(g.make_size_and_deterministic_replacements(s, d, more_replacements))
1182+
flat2rand.update(more_replacements)
11741183
return flat2rand
11751184

11761185
@change_flags(compute_test_value='off')
def set_size_and_deterministic(self, node, s, d, more_replacements=None):
    """Clone ``node`` with optimizations and size/determinism replacements.

    Both replacement mappings are built before any cloning.  The
    optimization replacements are applied first with a default (strict)
    clone, then the size/deterministic replacements with a non-strict clone.

    Parameters
    ----------
    node
        Theano node (graph output) to clone.
    s
        Sample size passed on to both replacement builders.
    d
        Deterministic flag passed on to both replacement builders.
    more_replacements : dict, optional
        Extra replacements passed on to the size/deterministic builder.

    Returns
    -------
    The cloned node.
    """
    optimization_map = self.get_optimization_replacements(s, d)
    size_map = self.make_size_and_deterministic_replacements(
        s, d, more_replacements
    )
    optimized = theano.clone(node, optimization_map)
    result = theano.clone(optimized, size_map, strict=False)
    # The test value is derived from the optimized node, as before.
    try_to_set_test_value(optimized, result, s)
    return result
11841193

0 commit comments

Comments
 (0)