@@ -884,9 +884,11 @@ def params(self):
884
884
else :
885
885
return collect_shared_to_list (self .shared_params )
886
886
887
- def _new_initial_shape (self , size , dim ):
887
+ def _new_initial_shape (self , size , dim , more_replacements ):
888
888
if self .batched :
889
- return tt .stack ([size , self .bdim , dim ])
889
+ bdim = tt .as_tensor (self .bdim )
890
+ bdim = theano .clone (bdim , more_replacements )
891
+ return tt .stack ([size , bdim , dim ])
890
892
else :
891
893
return tt .stack ([size , dim ])
892
894
@@ -908,7 +910,7 @@ def ndim(self):
908
910
def ddim (self ):
909
911
return self .ordering .size
910
912
911
- def _new_initial (self , size , deterministic ):
913
+ def _new_initial (self , size , deterministic , more_replacements = None ):
912
914
if size is None :
913
915
size = 1
914
916
if not isinstance (deterministic , tt .Variable ):
@@ -921,7 +923,7 @@ def _new_initial(self, size, deterministic):
921
923
dtype = self .symbolic_initial .dtype
922
924
dim = tt .as_tensor (dim )
923
925
size = tt .as_tensor (size )
924
- shape = self ._new_initial_shape (size , dim )
926
+ shape = self ._new_initial_shape (size , dim , more_replacements )
925
927
# apply optimizations if possible
926
928
if not isinstance (deterministic , tt .Variable ):
927
929
if deterministic :
@@ -951,33 +953,38 @@ def symbolic_random2d(self):
951
953
@change_flags (compute_test_value = 'off' )
952
954
def set_size_and_deterministic (self , node , s , d , more_replacements = None ):
953
955
flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
954
- node_out = theano .clone (node , flat2rand )
956
+ node_out = theano .clone (node , flat2rand , strict = False )
955
957
try_to_set_test_value (node , node_out , s )
956
958
return node_out
957
959
958
960
def to_flat_input (self , node ):
959
961
"""Replace vars with flattened view stored in self.inputs
960
962
"""
961
- return theano .clone (node , self .replacements )
963
+ return theano .clone (node , self .replacements , strict = False )
962
964
963
965
def symbolic_sample_over_posterior (self , node ):
964
966
node = self .to_flat_input (node )
967
+ random = self .symbolic_random .astype (self .symbolic_initial .dtype )
968
+ random = tt .patternbroadcast (random , self .symbolic_initial .broadcastable )
965
969
966
970
def sample (post ):
967
971
return theano .clone (node , {self .input : post })
968
972
969
973
nodes , _ = theano .scan (
970
- sample , self . symbolic_random )
974
+ sample , random )
971
975
return nodes
972
976
973
977
def symbolic_single_sample (self , node ):
974
978
node = self .to_flat_input (node )
979
+ random = self .symbolic_random .astype (self .symbolic_initial .dtype )
980
+ random = tt .patternbroadcast (random , self .symbolic_initial .broadcastable )
975
981
return theano .clone (
976
- node , {self .input : self . symbolic_random [0 ]}
982
+ node , {self .input : random [0 ]}
977
983
)
978
984
979
985
def make_size_and_deterministic_replacements (self , s , d , more_replacements = None ):
980
- initial = self ._new_initial (s , d )
986
+ initial = self ._new_initial (s , d , more_replacements )
987
+ initial = tt .patternbroadcast (initial , self .symbolic_initial .broadcastable )
981
988
if more_replacements :
982
989
initial = theano .clone (initial , more_replacements )
983
990
return {self .symbolic_initial : initial }
@@ -1072,6 +1079,7 @@ def __init__(self, groups, model=None):
1072
1079
model = modelcontext (model )
1073
1080
if not model .free_RVs :
1074
1081
raise TypeError ('Model does not have FreeRVs' )
1082
+ self .groups = list ()
1075
1083
seen = set ()
1076
1084
rest = None
1077
1085
for g in groups :
@@ -1085,12 +1093,13 @@ def __init__(self, groups, model=None):
1085
1093
if set (g .group ) & seen :
1086
1094
raise GroupError ('Found duplicates in groups' )
1087
1095
seen .update (g .group )
1096
+ self .groups .append (g )
1088
1097
if set (model .free_RVs ) - seen :
1089
1098
if rest is None :
1090
1099
raise GroupError ('No approximation is specified for the rest variables' )
1091
1100
else :
1092
1101
rest .__init_group__ (list (set (model .free_RVs ) - seen ))
1093
- self .groups = groups
1102
+ self .groups . append ( rest )
1094
1103
self .model = model
1095
1104
1096
1105
@property
@@ -1168,17 +1177,17 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
1168
1177
if more_replacements is None :
1169
1178
more_replacements = {}
1170
1179
flat2rand = collections .OrderedDict ()
1171
- flat2rand .update (more_replacements )
1172
1180
for g in self .groups :
1173
1181
flat2rand .update (g .make_size_and_deterministic_replacements (s , d , more_replacements ))
1182
+ flat2rand .update (more_replacements )
1174
1183
return flat2rand
1175
1184
1176
1185
@change_flags (compute_test_value = 'off' )
1177
1186
def set_size_and_deterministic (self , node , s , d , more_replacements = None ):
1178
1187
optimizations = self .get_optimization_replacements (s , d )
1179
- node = theano .clone (node , optimizations )
1180
1188
flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1181
- node_out = theano .clone (node , flat2rand )
1189
+ node = theano .clone (node , optimizations )
1190
+ node_out = theano .clone (node , flat2rand , strict = False )
1182
1191
try_to_set_test_value (node , node_out , s )
1183
1192
return node_out
1184
1193
0 commit comments