@@ -92,7 +92,8 @@ def embedding_net(xx,
9292 bavg = 0.0 ,
9393 seed = None ,
9494 trainable = True ,
95- uniform_seed = False ):
95+ uniform_seed = False ,
96+ initial_variables = None ):
9697 r"""The embedding network.
9798
9899 The embedding network function :math:`\mathcal{N}` is constructed as the
@@ -141,6 +142,11 @@ def embedding_net(xx,
141142 Random seed for initializing network parameters
142143 trainable: boolean
143144 If the network is trainable
145+ uniform_seed : boolean
146+ Only for backward compatibility: if True, every variable initializer uses the same `seed` unchanged (the old behavior) instead of a distinct per-variable offset of `seed`
147+ initial_variables : dict
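148+ Dict holding initial values for the embedding net variables, keyed by '<variable scope>/matrix_<ii><name_suffix>', '<variable scope>/bias_<ii><name_suffix>' and, when resnet_dt is set, '<variable scope>/idt_<ii><name_suffix>'; if not None, these values are used as constant initializers instead of the random initializers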
149+
144150
145151 References
146152 ----------
@@ -152,37 +158,47 @@ def embedding_net(xx,
152158 outputs_size = [input_shape [1 ]] + network_size
153159
154160 for ii in range (1 , len (outputs_size )):
155- w = tf .get_variable ('matrix_' + str (ii )+ name_suffix ,
161+ w_initializer = tf .random_normal_initializer (
162+ stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]),
163+ seed = seed if (seed is None or uniform_seed ) else seed + ii * 3 + 0
164+ )
165+ b_initializer = tf .random_normal_initializer (
166+ stddev = stddev ,
167+ mean = bavg ,
168+ seed = seed if (seed is None or uniform_seed ) else seed + 3 * ii + 1
169+ )
170+ if initial_variables is not None :
171+ scope = tf .get_variable_scope ().name
172+ w_initializer = tf .constant_initializer (initial_variables [scope + '/matrix_' + str (ii )+ name_suffix ])
173+ b_initializer = tf .constant_initializer (initial_variables [scope + '/bias_' + str (ii )+ name_suffix ])
174+ w = tf .get_variable ('matrix_' + str (ii )+ name_suffix ,
156175 [outputs_size [ii - 1 ], outputs_size [ii ]],
157176 precision ,
158- tf .random_normal_initializer (
159- stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]),
160- seed = seed if (seed is None or uniform_seed ) else seed + ii * 3 + 0
161- ),
177+ w_initializer ,
162178 trainable = trainable )
163179 variable_summaries (w , 'matrix_' + str (ii )+ name_suffix )
164180
165181 b = tf .get_variable ('bias_' + str (ii )+ name_suffix ,
166182 [1 , outputs_size [ii ]],
167183 precision ,
168- tf .random_normal_initializer (
169- stddev = stddev ,
170- mean = bavg ,
171- seed = seed if (seed is None or uniform_seed ) else seed + 3 * ii + 1
172- ),
184+ b_initializer ,
173185 trainable = trainable )
174186 variable_summaries (b , 'bias_' + str (ii )+ name_suffix )
175187
176188 hidden = tf .reshape (activation_fn (tf .matmul (xx , w ) + b ), [- 1 , outputs_size [ii ]])
177189 if resnet_dt :
190+ idt_initializer = tf .random_normal_initializer (
191+ stddev = 0.001 ,
192+ mean = 1.0 ,
193+ seed = seed if (seed is None or uniform_seed ) else seed + 3 * ii + 2
194+ )
195+ if initial_variables is not None :
196+ scope = tf .get_variable_scope ().name
197+ idt_initializer = tf .constant_initializer (initial_variables [scope + '/idt_' + str (ii )+ name_suffix ])
178198 idt = tf .get_variable ('idt_' + str (ii )+ name_suffix ,
179199 [1 , outputs_size [ii ]],
180200 precision ,
181- tf .random_normal_initializer (
182- stddev = 0.001 ,
183- mean = 1.0 ,
184- seed = seed if (seed is None or uniform_seed ) else seed + 3 * ii + 2
185- ),
201+ idt_initializer ,
186202 trainable = trainable )
187203 variable_summaries (idt , 'idt_' + str (ii )+ name_suffix )
188204
0 commit comments