@@ -210,8 +210,7 @@ def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
210
210
more_tf_params = []
211
211
if more_replacements is None :
212
212
more_replacements = dict ()
213
- tf_target = self (tf_n_mc , more_tf_params = more_tf_params )
214
- tf_target = theano .clone (tf_target , more_replacements , strict = False )
213
+ tf_target = self (tf_n_mc , more_tf_params = more_tf_params , more_replacements = more_replacements )
215
214
grads = pm .updates .get_or_compute_grads (tf_target , self .obj_params + more_tf_params )
216
215
if total_grad_norm_constraint is not None :
217
216
grads = pm .total_norm_constraint (grads , total_grad_norm_constraint )
@@ -228,8 +227,7 @@ def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
228
227
more_obj_params = []
229
228
if more_replacements is None :
230
229
more_replacements = dict ()
231
- obj_target = self (obj_n_mc , more_obj_params = more_obj_params )
232
- obj_target = theano .clone (obj_target , more_replacements , strict = False )
230
+ obj_target = self (obj_n_mc , more_obj_params = more_obj_params , more_replacements = more_replacements )
233
231
grads = pm .updates .get_or_compute_grads (obj_target , self .obj_params + more_obj_params )
234
232
if total_grad_norm_constraint is not None :
235
233
grads = pm .total_norm_constraint (grads , total_grad_norm_constraint )
@@ -329,10 +327,7 @@ def score_function(self, sc_n_mc=None, more_replacements=None, fn_kwargs=None):
329
327
raise NotImplementedError ('%s does not have loss' % self .op )
330
328
if more_replacements is None :
331
329
more_replacements = {}
332
- loss = theano .clone (
333
- self (sc_n_mc ),
334
- more_replacements ,
335
- strict = False )
330
+ loss = self (sc_n_mc , more_replacements = more_replacements )
336
331
return theano .function ([], loss , ** fn_kwargs )
337
332
338
333
@change_flags (compute_test_value = 'off' )
@@ -342,7 +337,7 @@ def __call__(self, nmc, **kwargs):
342
337
else :
343
338
m = 1.
344
339
a = self .op .apply (self .tf )
345
- a = self .approx .set_size_and_deterministic (a , nmc , 0 )
340
+ a = self .approx .set_size_and_deterministic (a , nmc , 0 , kwargs . get ( 'more_replacements' ) )
346
341
return m * self .op .T (a )
347
342
348
343
@@ -954,21 +949,19 @@ def symbolic_random2d(self):
954
949
return self .symbolic_random
955
950
956
951
@change_flags (compute_test_value = 'off' )
957
- def set_size_and_deterministic (self , node , s , d ):
958
- flat2rand = self .make_size_and_deterministic_replacements (s , d )
952
+ def set_size_and_deterministic (self , node , s , d , more_replacements = None ):
953
+ flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
959
954
node_out = theano .clone (node , flat2rand )
960
955
try_to_set_test_value (node , node_out , s )
961
956
return node_out
962
957
963
- def to_flat_input (self , node , more_replacements = None ):
958
+ def to_flat_input (self , node ):
964
959
"""Replace vars with flattened view stored in self.inputs
965
960
"""
966
- if more_replacements :
967
- node = theano .clone (node , more_replacements )
968
961
return theano .clone (node , self .replacements , strict = False )
969
962
970
- def symbolic_sample_over_posterior (self , node , more_replacements = None ):
971
- node = self .to_flat_input (node , more_replacements )
963
+ def symbolic_sample_over_posterior (self , node ):
964
+ node = self .to_flat_input (node )
972
965
973
966
def sample (post ):
974
967
return theano .clone (node , {self .input : post })
@@ -977,24 +970,23 @@ def sample(post):
977
970
sample , self .symbolic_random )
978
971
return nodes
979
972
980
- def symbolic_single_sample (self , node , more_replacements = None ):
981
- node = self .to_flat_input (node , more_replacements )
973
+ def symbolic_single_sample (self , node ):
974
+ node = self .to_flat_input (node )
982
975
return theano .clone (
983
976
node , {self .input : self .symbolic_random [0 ]}
984
977
)
985
978
986
- def make_size_and_deterministic_replacements (self , s , d ):
987
- initial_ = self ._new_initial (s , d )
988
- return collections . OrderedDict ({
989
- self . symbolic_initial : initial_
990
- })
979
+ def make_size_and_deterministic_replacements (self , s , d , more_replacements = None ):
980
+ initial = self ._new_initial (s , d )
981
+ if more_replacements :
982
+ initial = theano . clone ( initial , more_replacements )
983
+ return { self . symbolic_initial : initial }
991
984
992
985
@node_property
993
986
def symbolic_normalizing_constant (self ):
994
987
t = self .to_flat_input (
995
988
tt .max ([v .scaling for v in self .group ]))
996
989
t = self .symbolic_single_sample (t )
997
- t = self .set_size_and_deterministic (t , 1 , 1 ) # remove random, we do not it here at all
998
990
return pm .floatX (t )
999
991
1000
992
@node_property
@@ -1007,7 +999,6 @@ def symbolic_logq(self):
1007
999
s = self .group [0 ].scaling
1008
1000
s = self .to_flat_input (s )
1009
1001
s = self .symbolic_single_sample (s )
1010
- s = self .set_size_and_deterministic (s , 1 , 1 )
1011
1002
return self .symbolic_logq_not_scaled * s
1012
1003
else :
1013
1004
return self .symbolic_logq_not_scaled
@@ -1173,31 +1164,32 @@ def replacements(self):
1173
1164
g .replacements .items () for g in self .groups
1174
1165
))
1175
1166
1176
- def make_size_and_deterministic_replacements (self , s , d ):
1167
+ def make_size_and_deterministic_replacements (self , s , d , more_replacements = None ):
1168
+ if more_replacements is None :
1169
+ more_replacements = {}
1177
1170
flat2rand = collections .OrderedDict ()
1171
+ flat2rand .update (more_replacements )
1178
1172
for g in self .groups :
1179
- flat2rand .update (g .make_size_and_deterministic_replacements (s , d ))
1173
+ flat2rand .update (g .make_size_and_deterministic_replacements (s , d , more_replacements ))
1180
1174
return flat2rand
1181
1175
1182
1176
@change_flags (compute_test_value = 'off' )
1183
- def set_size_and_deterministic (self , node , s , d ):
1177
+ def set_size_and_deterministic (self , node , s , d , more_replacements = None ):
1184
1178
optimizations = self .get_optimization_replacements (s , d )
1185
1179
node = theano .clone (node , optimizations )
1186
- flat2rand = self .make_size_and_deterministic_replacements (s , d )
1180
+ flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1187
1181
node_out = theano .clone (node , flat2rand )
1188
1182
try_to_set_test_value (node , node_out , s )
1189
1183
return node_out
1190
1184
1191
- def to_flat_input (self , node , more_replacements = None ):
1185
+ def to_flat_input (self , node ):
1192
1186
"""
1193
1187
Replaces vars with flattened view stored in self.inputs
1194
1188
"""
1195
- if more_replacements :
1196
- node = theano .clone (node , more_replacements )
1197
1189
return theano .clone (node , self .replacements , strict = False )
1198
1190
1199
- def symbolic_sample_over_posterior (self , node , more_replacements = None ):
1200
- node = self .to_flat_input (node , more_replacements )
1191
+ def symbolic_sample_over_posterior (self , node ):
1192
+ node = self .to_flat_input (node )
1201
1193
1202
1194
def sample (* post ):
1203
1195
return theano .clone (node , dict (zip (self .inputs , post )))
@@ -1206,8 +1198,8 @@ def sample(*post):
1206
1198
sample , self .symbolic_randoms )
1207
1199
return nodes
1208
1200
1209
- def symbolic_single_sample (self , node , more_replacements = None ):
1210
- node = self .to_flat_input (node , more_replacements )
1201
+ def symbolic_single_sample (self , node ):
1202
+ node = self .to_flat_input (node )
1211
1203
post = [v [0 ] for v in self .symbolic_randoms ]
1212
1204
inp = self .inputs
1213
1205
return theano .clone (
@@ -1244,10 +1236,10 @@ def sample_node(self, node, size=None,
1244
1236
"""
1245
1237
node_in = node
1246
1238
if size is None :
1247
- node_out = self .symbolic_single_sample (node , more_replacements )
1239
+ node_out = self .symbolic_single_sample (node )
1248
1240
else :
1249
- node_out = self .symbolic_sample_over_posterior (node , more_replacements )
1250
- node_out = self .set_size_and_deterministic (node_out , size , deterministic )
1241
+ node_out = self .symbolic_sample_over_posterior (node )
1242
+ node_out = self .set_size_and_deterministic (node_out , size , deterministic , more_replacements )
1251
1243
try_to_set_test_value (node_in , node_out , size )
1252
1244
return node_out
1253
1245
0 commit comments