@@ -16,7 +16,9 @@ def __init__ (self, jdata):
1616 .add ('axis_neuron' , int , default = 4 , alias = 'n_axis_neuron' ) \
1717 .add ('resnet_dt' ,bool , default = False ) \
1818 .add ('trainable' ,bool , default = True ) \
19- .add ('seed' , int )
19+ .add ('seed' , int ) \
20+ .add ('exclude_types' , list , default = []) \
21+ .add ('set_davg_zero' , bool , default = False )
2022 class_data = args .parse (jdata )
2123 self .sel_a = class_data ['sel' ]
2224 self .rcut_r = class_data ['rcut' ]
@@ -26,6 +28,13 @@ def __init__ (self, jdata):
2628 self .filter_resnet_dt = class_data ['resnet_dt' ]
2729 self .seed = class_data ['seed' ]
2830 self .trainable = class_data ['trainable' ]
31+ exclude_types = class_data ['exclude_types' ]
32+ self .exclude_types = set ()
33+ for tt in exclude_types :
34+ assert (len (tt ) == 2 )
35+ self .exclude_types .add ((tt [0 ], tt [1 ]))
36+ self .exclude_types .add ((tt [1 ], tt [0 ]))
37+ self .set_davg_zero = class_data ['set_davg_zero' ]
2938
3039 # descrpt config
3140 self .sel_r = [ 0 for ii in range (len (self .sel_a )) ]
@@ -124,7 +133,8 @@ def compute_input_stats (self,
124133 all_davg .append (davg )
125134 all_dstd .append (dstd )
126135
127- self .davg = np .array (all_davg )
136+ if not self .set_davg_zero :
137+ self .davg = np .array (all_davg )
128138 self .dstd = np .array (all_dstd )
129139
130140
@@ -235,7 +245,7 @@ def _pass_filter(self,
235245 [ 0 , start_index * self .ndescrpt ],
236246 [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
237247 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
238- layer , qmat = self ._filter (inputs_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable )
248+ layer , qmat = self ._filter (inputs_i , type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable )
239249 layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
240250 qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
241251 output .append (layer )
@@ -300,6 +310,7 @@ def _compute_std (self,sumv2, sumv, sumn) :
300310
301311 def _filter (self ,
302312 inputs ,
313+ type_input ,
303314 natoms ,
304315 activation_fn = tf .nn .tanh ,
305316 stddev = 1.0 ,
@@ -326,35 +337,39 @@ def _filter(self,
326337 # with (natom x nei_type_i) x 4
327338 inputs_reshape = tf .reshape (inputs_i , [- 1 , 4 ])
328339 xyz_scatter = tf .reshape (tf .slice (inputs_reshape , [0 ,0 ],[- 1 ,1 ]),[- 1 ,1 ])
329- for ii in range (1 , len (outputs_size )):
330- w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
331- [outputs_size [ii - 1 ], outputs_size [ii ]],
332- global_tf_float_precision ,
333- tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
334- trainable = trainable )
335- b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
336- [1 , outputs_size [ii ]],
337- global_tf_float_precision ,
338- tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
339- trainable = trainable )
340- if self .filter_resnet_dt :
341- idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
342- [1 , outputs_size [ii ]],
343- global_tf_float_precision ,
344- tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
345- trainable = trainable )
346- if outputs_size [ii ] == outputs_size [ii - 1 ]:
347- if self .filter_resnet_dt :
348- xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
349- else :
350- xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b )
351- elif outputs_size [ii ] == outputs_size [ii - 1 ] * 2 :
352- if self .filter_resnet_dt :
353- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
354- else :
355- xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn (tf .matmul (xyz_scatter , w ) + b )
356- else :
357- xyz_scatter = activation_fn (tf .matmul (xyz_scatter , w ) + b )
340+ if (type_input , type_i ) not in self .exclude_types :
341+ for ii in range (1 , len (outputs_size )):
342+ w = tf .get_variable ('matrix_' + str (ii )+ '_' + str (type_i ),
343+ [outputs_size [ii - 1 ], outputs_size [ii ]],
344+ global_tf_float_precision ,
345+ tf .random_normal_initializer (stddev = stddev / np .sqrt (outputs_size [ii ]+ outputs_size [ii - 1 ]), seed = seed ),
346+ trainable = trainable )
347+ b = tf .get_variable ('bias_' + str (ii )+ '_' + str (type_i ),
348+ [1 , outputs_size [ii ]],
349+ global_tf_float_precision ,
350+ tf .random_normal_initializer (stddev = stddev , mean = bavg , seed = seed ),
351+ trainable = trainable )
352+ if self .filter_resnet_dt :
353+ idt = tf .get_variable ('idt_' + str (ii )+ '_' + str (type_i ),
354+ [1 , outputs_size [ii ]],
355+ global_tf_float_precision ,
356+ tf .random_normal_initializer (stddev = 0.001 , mean = 1.0 , seed = seed ),
357+ trainable = trainable )
358+ if outputs_size [ii ] == outputs_size [ii - 1 ]:
359+ if self .filter_resnet_dt :
360+ xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
361+ else :
362+ xyz_scatter += activation_fn (tf .matmul (xyz_scatter , w ) + b )
363+ elif outputs_size [ii ] == outputs_size [ii - 1 ] * 2 :
364+ if self .filter_resnet_dt :
365+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn (tf .matmul (xyz_scatter , w ) + b ) * idt
366+ else :
367+ xyz_scatter = tf .concat ([xyz_scatter ,xyz_scatter ], 1 ) + activation_fn (tf .matmul (xyz_scatter , w ) + b )
368+ else :
369+ xyz_scatter = activation_fn (tf .matmul (xyz_scatter , w ) + b )
370+ else :
371+ w = tf .zeros ((outputs_size [0 ], outputs_size [- 1 ]), dtype = global_tf_float_precision )
372+ xyz_scatter = tf .matmul (xyz_scatter , w )
358373 # natom x nei_type_i x out_size
359374 xyz_scatter = tf .reshape (xyz_scatter , (- 1 , shape_i [1 ]// 4 , outputs_size [- 1 ]))
360375 xyz_scatter_total .append (xyz_scatter )
0 commit comments