@@ -28,6 +28,7 @@ def __init__ (self, jdata):
2828 .add ('neuron' , list , default = [10 , 20 , 40 ]) \
2929 .add ('axis_neuron' , int , default = 4 , alias = 'n_axis_neuron' ) \
3030 .add ('resnet_dt' ,bool , default = False ) \
31+ .add ('trainable' ,bool , default = True ) \
3132 .add ('seed' , int )
3233 class_data = args .parse (jdata )
3334 self .sel_a = class_data ['sel' ]
@@ -37,6 +38,7 @@ def __init__ (self, jdata):
3738 self .n_axis_neuron = class_data ['axis_neuron' ]
3839 self .filter_resnet_dt = class_data ['resnet_dt' ]
3940 self .seed = class_data ['seed' ]
41+ self .trainable = class_data ['trainable' ]
4042
4143 # descrpt config
4244 self .sel_r = [ 0 for ii in range (len (self .sel_a )) ]
@@ -167,7 +169,7 @@ def build (self,
167169
168170 self .descrpt_reshape = tf .reshape (self .descrpt , [- 1 , self .ndescrpt ])
169171
170- self .dout , self .qmat = self ._pass_filter (self .descrpt_reshape , natoms , suffix = suffix , reuse = reuse )
172+ self .dout , self .qmat = self ._pass_filter (self .descrpt_reshape , natoms , suffix = suffix , reuse = reuse , trainable = self . trainable )
171173
172174 return self .dout
173175
@@ -201,7 +203,8 @@ def _pass_filter(self,
201203 inputs ,
202204 natoms ,
203205 reuse = None ,
204- suffix = '' ) :
206+ suffix = '' ,
207+ trainable = True ) :
205208 start_index = 0
206209 inputs = tf .reshape (inputs , [- 1 , self .ndescrpt * natoms [0 ]])
207210 shape = inputs .get_shape ().as_list ()
@@ -212,7 +215,7 @@ def _pass_filter(self,
212215 [ 0 , start_index * self .ndescrpt ],
213216 [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
214217 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
215- layer , qmat = self ._filter (inputs_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed )
218+ layer , qmat = self ._filter (inputs_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable )
216219 layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
217220 qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
218221 output .append (layer )
@@ -297,7 +300,8 @@ def _filter(self,
297300 bavg = 0.0 ,
298301 name = 'linear' ,
299302 reuse = None ,
300- seed = None ):
303+ seed = None ,
304+ trainable = True ):
301305 # natom x (nei x 4)
302306 shape = inputs .get_shape ().as_list ()
303307 outputs_size = [1 ] + self .filter_neuron
@@ -320,16 +324,19 @@ def _filter(self,
320324 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
321325 [outputs_size [ii - 1 ], outputs_size [ii ]],
322326 global_tf_float_precision ,
323- tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ))
327+ tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
328+ trainable = trainable )
324329 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
325330 [1 , outputs_size [ii ]],
326331 global_tf_float_precision ,
327- tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ))
332+ tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
333+ trainable = trainable )
328334 if self .filter_resnet_dt :
329335 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
330336 [1 , outputs_size [ii ]],
331337 global_tf_float_precision ,
332- tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ))
338+ tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
339+ trainable = trainable )
333340 if outputs_size [ii ] == outputs_size [ii - 1 ]:
334341 if self .filter_resnet_dt :
335342 xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
@@ -376,7 +383,8 @@ def _filter_type_ext(self,
376383 bavg = 0.0 ,
377384 name = 'linear' ,
378385 reuse = None ,
379- seed = None ):
386+ seed = None ,
387+ trainable = True ):
380388 # natom x (nei x 4)
381389 shape = inputs .get_shape ().as_list ()
382390 outputs_size = [1 ] + self .filter_neuron
@@ -401,16 +409,19 @@ def _filter_type_ext(self,
401409 w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
402410 [outputs_size [ii - 1 ], outputs_size [ii ]],
403411 global_tf_float_precision ,
404- tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ))
412+ tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
413+ trainable = trainable )
405414 b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
406415 [1 , outputs_size [ii ]],
407416 global_tf_float_precision ,
408- tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ))
417+ tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
418+ trainable = trainable )
409419 if self .filter_resnet_dt :
410420 idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
411421 [1 , outputs_size [ii ]],
412422 global_tf_float_precision ,
413- tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ))
423+ tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
424+ trainable = trainable )
414425 if outputs_size [ii ] == outputs_size [ii - 1 ]:
415426 if self .filter_resnet_dt :
416427 xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
0 commit comments