
Commit 9168182

change replacements apply method
1 parent 85a1f0e commit 9168182

File tree

3 files changed: +43 -40 lines changed
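
In short: user-supplied symbolic replacements are no longer cloned into the node up front (in to_flat_input / the sampling helpers); they are now passed down to set_size_and_deterministic and applied together with the size/deterministic substitutions. A minimal usage sketch, assuming the three_var_aevb_approx / aevb_initial fixtures from the new test below:

    import numpy as np
    import theano.tensor as tt

    inp = tt.matrix(dtype='float32')            # symbolic stand-in for the encoder input
    approx = three_var_aevb_approx              # fixture: an AEVB approximation
    sampled = approx.sample_node(
        approx.model.one, 2,                    # two posterior samples of `one`
        more_replacements={aevb_initial: inp})  # swap aevb_initial for inp in the sampled graph
    sampled.eval({inp: np.random.rand(7, 7).astype('float32')})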

pymc3/tests/test_variational_inference.py

Lines changed: 11 additions & 0 deletions
@@ -217,6 +217,17 @@ def test_sample_aevb(three_var_aevb_approx, aevb_initial):
     assert trace[0]['three'].shape == (10, 1, 2)
 
 
+def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial):
+    inp = tt.matrix(dtype='float32')
+    three_var_aevb_approx.sample_node(
+        three_var_aevb_approx.model.one, 2,
+        more_replacements={aevb_initial: inp}).eval({inp: np.random.rand(7, 7).astype('float32')})
+
+    three_var_aevb_approx.sample_node(
+        three_var_aevb_approx.model.one, None,
+        more_replacements={aevb_initial: inp}).eval({inp: np.random.rand(7, 7).astype('float32')})
+
+
 def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model):
     cls, kw = parametric_grouped_approxes
     approx = cls([three_var_model.one], model=three_var_model, **kw)

pymc3/variational/operators.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def __call__(self, nmc, **kwargs):
         params = self.test_params + kwargs['more_tf_params']
         grad *= pm.floatX(-1)
         grads = tt.grad(None, params, known_grads={z: grad})
-        return self.approx.set_size_and_deterministic(grads, nmc, 0)
+        return self.approx.set_size_and_deterministic(grads, nmc, 0, kwargs.get('more_replacements'))
 
 
 class KSD(Operator):
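
Note that kwargs.get('more_replacements') yields None when the caller does not pass the argument, and the new more_replacements=None default downstream treats None (or an empty dict) as "nothing extra to replace", so existing call sites keep working. A stand-alone sketch of that pattern, with hypothetical names (build_replacements, objective_call), not the actual PyMC3 code:

    def build_replacements(s, d, more_replacements=None):
        replacements = {'size': s, 'deterministic': d}
        if more_replacements:                      # None and {} both skip this branch
            replacements.update(more_replacements)
        return replacements

    def objective_call(nmc, **kwargs):
        # .get() returns None when 'more_replacements' was never supplied
        return build_replacements(nmc, 0, kwargs.get('more_replacements'))

    print(objective_call(100))                                    # {'size': 100, 'deterministic': 0}
    print(objective_call(100, more_replacements={'x': 'input'}))  # extra entry merged in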

pymc3/variational/opvi.py

Lines changed: 31 additions & 39 deletions
@@ -210,8 +210,7 @@ def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
             more_tf_params = []
         if more_replacements is None:
             more_replacements = dict()
-        tf_target = self(tf_n_mc, more_tf_params=more_tf_params)
-        tf_target = theano.clone(tf_target, more_replacements, strict=False)
+        tf_target = self(tf_n_mc, more_tf_params=more_tf_params, more_replacements=more_replacements)
         grads = pm.updates.get_or_compute_grads(tf_target, self.obj_params + more_tf_params)
         if total_grad_norm_constraint is not None:
             grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
@@ -228,8 +227,7 @@ def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
             more_obj_params = []
         if more_replacements is None:
             more_replacements = dict()
-        obj_target = self(obj_n_mc, more_obj_params=more_obj_params)
-        obj_target = theano.clone(obj_target, more_replacements, strict=False)
+        obj_target = self(obj_n_mc, more_obj_params=more_obj_params, more_replacements=more_replacements)
         grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
         if total_grad_norm_constraint is not None:
             grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
@@ -329,10 +327,7 @@ def score_function(self, sc_n_mc=None, more_replacements=None, fn_kwargs=None):
             raise NotImplementedError('%s does not have loss' % self.op)
         if more_replacements is None:
             more_replacements = {}
-        loss = theano.clone(
-            self(sc_n_mc),
-            more_replacements,
-            strict=False)
+        loss = self(sc_n_mc, more_replacements=more_replacements)
         return theano.function([], loss, **fn_kwargs)
 
     @change_flags(compute_test_value='off')
@@ -342,7 +337,7 @@ def __call__(self, nmc, **kwargs):
         else:
             m = 1.
         a = self.op.apply(self.tf)
-        a = self.approx.set_size_and_deterministic(a, nmc, 0)
+        a = self.approx.set_size_and_deterministic(a, nmc, 0, kwargs.get('more_replacements'))
         return m * self.op.T(a)
 
 
@@ -954,21 +949,19 @@ def symbolic_random2d(self):
         return self.symbolic_random
 
     @change_flags(compute_test_value='off')
-    def set_size_and_deterministic(self, node, s, d):
-        flat2rand = self.make_size_and_deterministic_replacements(s, d)
+    def set_size_and_deterministic(self, node, s, d, more_replacements=None):
+        flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
         node_out = theano.clone(node, flat2rand)
         try_to_set_test_value(node, node_out, s)
         return node_out
 
-    def to_flat_input(self, node, more_replacements=None):
+    def to_flat_input(self, node):
         """Replace vars with flattened view stored in self.inputs
         """
-        if more_replacements:
-            node = theano.clone(node, more_replacements)
         return theano.clone(node, self.replacements, strict=False)
 
-    def symbolic_sample_over_posterior(self, node, more_replacements=None):
-        node = self.to_flat_input(node, more_replacements)
+    def symbolic_sample_over_posterior(self, node):
+        node = self.to_flat_input(node)
 
         def sample(post):
             return theano.clone(node, {self.input: post})
@@ -977,24 +970,23 @@ def sample(post):
             sample, self.symbolic_random)
         return nodes
 
-    def symbolic_single_sample(self, node, more_replacements=None):
-        node = self.to_flat_input(node, more_replacements)
+    def symbolic_single_sample(self, node):
+        node = self.to_flat_input(node)
         return theano.clone(
             node, {self.input: self.symbolic_random[0]}
         )
 
-    def make_size_and_deterministic_replacements(self, s, d):
-        initial_ = self._new_initial(s, d)
-        return collections.OrderedDict({
-            self.symbolic_initial: initial_
-        })
+    def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
+        initial = self._new_initial(s, d)
+        if more_replacements:
+            initial = theano.clone(initial, more_replacements)
+        return {self.symbolic_initial: initial}
 
     @node_property
     def symbolic_normalizing_constant(self):
         t = self.to_flat_input(
             tt.max([v.scaling for v in self.group]))
         t = self.symbolic_single_sample(t)
-        t = self.set_size_and_deterministic(t, 1, 1)  # remove random, we do not it here at all
         return pm.floatX(t)
 
     @node_property
@@ -1007,7 +999,6 @@ def symbolic_logq(self):
             s = self.group[0].scaling
             s = self.to_flat_input(s)
             s = self.symbolic_single_sample(s)
-            s = self.set_size_and_deterministic(s, 1, 1)
             return self.symbolic_logq_not_scaled * s
         else:
             return self.symbolic_logq_not_scaled
@@ -1173,31 +1164,32 @@ def replacements(self):
             g.replacements.items() for g in self.groups
         ))
 
-    def make_size_and_deterministic_replacements(self, s, d):
+    def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
+        if more_replacements is None:
+            more_replacements = {}
         flat2rand = collections.OrderedDict()
+        flat2rand.update(more_replacements)
         for g in self.groups:
-            flat2rand.update(g.make_size_and_deterministic_replacements(s, d))
+            flat2rand.update(g.make_size_and_deterministic_replacements(s, d, more_replacements))
         return flat2rand
 
     @change_flags(compute_test_value='off')
-    def set_size_and_deterministic(self, node, s, d):
+    def set_size_and_deterministic(self, node, s, d, more_replacements=None):
         optimizations = self.get_optimization_replacements(s, d)
         node = theano.clone(node, optimizations)
-        flat2rand = self.make_size_and_deterministic_replacements(s, d)
+        flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
         node_out = theano.clone(node, flat2rand)
         try_to_set_test_value(node, node_out, s)
         return node_out
 
-    def to_flat_input(self, node, more_replacements=None):
+    def to_flat_input(self, node):
         """
         Replaces vars with flattened view stored in self.inputs
         """
-        if more_replacements:
-            node = theano.clone(node, more_replacements)
         return theano.clone(node, self.replacements, strict=False)
 
-    def symbolic_sample_over_posterior(self, node, more_replacements=None):
-        node = self.to_flat_input(node, more_replacements)
+    def symbolic_sample_over_posterior(self, node):
+        node = self.to_flat_input(node)
 
         def sample(*post):
             return theano.clone(node, dict(zip(self.inputs, post)))
@@ -1206,8 +1198,8 @@ def sample(*post):
             sample, self.symbolic_randoms)
         return nodes
 
-    def symbolic_single_sample(self, node, more_replacements=None):
-        node = self.to_flat_input(node, more_replacements)
+    def symbolic_single_sample(self, node):
+        node = self.to_flat_input(node)
         post = [v[0] for v in self.symbolic_randoms]
         inp = self.inputs
         return theano.clone(
@@ -1244,10 +1236,10 @@ def sample_node(self, node, size=None,
         """
         node_in = node
         if size is None:
-            node_out = self.symbolic_single_sample(node, more_replacements)
+            node_out = self.symbolic_single_sample(node)
         else:
-            node_out = self.symbolic_sample_over_posterior(node, more_replacements)
-        node_out = self.set_size_and_deterministic(node_out, size, deterministic)
+            node_out = self.symbolic_sample_over_posterior(node)
+        node_out = self.set_size_and_deterministic(node_out, size, deterministic, more_replacements)
         try_to_set_test_value(node_in, node_out, size)
         return node_out
 
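
The net effect in opvi.py: to_flat_input, symbolic_sample_over_posterior and symbolic_single_sample lose their more_replacements argument, and extra replacements now travel with the size/deterministic replacements instead of triggering a separate up-front clone. A toy Theano illustration (not PyMC3 code; `initial` and `data` are stand-ins for symbolic_initial and an AEVB input) showing that, for this graph, the merged single replacement dict gives the same result as the old two-step clone:

    import numpy as np
    import theano
    import theano.tensor as tt

    initial = tt.matrix('initial')       # plays the role of symbolic_initial
    data = tt.matrix('data')             # plays the role of the AEVB input being replaced
    node = (initial + data).sum()

    new_initial = tt.zeros((3, 3))       # plays the role of _new_initial(s, d)
    inp = tt.matrix('inp')               # user-supplied replacement target
    more_replacements = {data: inp}

    # before this commit: two separate clones (more_replacements first, then size/deterministic)
    old = theano.clone(theano.clone(node, more_replacements), {initial: new_initial})

    # after this commit: one dict, with more_replacements also applied to the new initial
    flat2rand = dict(more_replacements)
    flat2rand[initial] = theano.clone(new_initial, more_replacements)
    new = theano.clone(node, flat2rand)

    x = np.ones((3, 3), dtype=theano.config.floatX)
    print(old.eval({inp: x}), new.eval({inp: x}))   # both evaluate to 9.0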

0 commit comments
