@@ -17,6 +17,7 @@ def __init__ (self, jdata):
1717 .add ('resnet_dt' ,bool , default = False ) \
1818 .add ('trainable' ,bool , default = True ) \
1919 .add ('seed' , int ) \
20+ .add ('type_one_side' , bool , default = False ) \
2021 .add ('exclude_types' , list , default = []) \
2122 .add ('set_davg_zero' , bool , default = False ) \
2223 .add ('activation_function' , str , default = 'tanh' ) \
@@ -39,6 +40,9 @@ def __init__ (self, jdata):
3940 self .exclude_types .add ((tt [0 ], tt [1 ]))
4041 self .exclude_types .add ((tt [1 ], tt [0 ]))
4142 self .set_davg_zero = class_data ['set_davg_zero' ]
43+ self .type_one_side = class_data ['type_one_side' ]
44+ if self .type_one_side and len (exclude_types ) != 0 :
45+ raise RuntimeError ('"type_one_side" is not compatible with "exclude_types"' )
4246
4347 # descrpt config
4448 self .sel_r = [ 0 for ii in range (len (self .sel_a )) ]
@@ -244,17 +248,27 @@ def _pass_filter(self,
244248 inputs = tf .reshape (inputs , [- 1 , self .ndescrpt * natoms [0 ]])
245249 output = []
246250 output_qmat = []
247- for type_i in range (self .ntypes ):
248- inputs_i = tf .slice (inputs ,
249- [ 0 , start_index * self .ndescrpt ],
250- [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
251+ if not self .type_one_side :
252+ for type_i in range (self .ntypes ):
253+ inputs_i = tf .slice (inputs ,
254+ [ 0 , start_index * self .ndescrpt ],
255+ [- 1 , natoms [2 + type_i ]* self .ndescrpt ] )
256+ inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
257+ layer , qmat = self ._filter (tf .cast (inputs_i , self .filter_precision ), type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
258+ layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
259+ qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
260+ output .append (layer )
261+ output_qmat .append (qmat )
262+ start_index += natoms [2 + type_i ]
263+ else :
264+ inputs_i = inputs
251265 inputs_i = tf .reshape (inputs_i , [- 1 , self .ndescrpt ])
252- layer , qmat = self ._filter (tf .cast (inputs_i , self .filter_precision ), type_i , name = 'filter_type_' + str (type_i )+ suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
253- layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_out ()])
254- qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [2 + type_i ] * self .get_dim_rot_mat_1 () * 3 ])
266+ type_i = - 1
267+ layer , qmat = self ._filter (tf .cast (inputs_i , self .filter_precision ), type_i , name = 'filter_type_all' + suffix , natoms = natoms , reuse = reuse , seed = self .seed , trainable = trainable , activation_fn = self .filter_activation_fn )
268+ layer = tf .reshape (layer , [tf .shape (inputs )[0 ], natoms [0 ] * self .get_dim_out ()])
269+ qmat = tf .reshape (qmat , [tf .shape (inputs )[0 ], natoms [0 ] * self .get_dim_rot_mat_1 () * 3 ])
255270 output .append (layer )
256271 output_qmat .append (qmat )
257- start_index += natoms [2 + type_i ]
258272 output = tf .concat (output , axis = 1 )
259273 output_qmat = tf .concat (output_qmat , axis = 1 )
260274 return output , output_qmat
0 commit comments