@@ -421,3 +421,76 @@ def build (self,
421421 count += 1
422422
423423 return tf .reshape (outs , [- 1 ])
424+
425+
426+ class DipoleFittingSeA () :
427+ def __init__ (self , jdata , descrpt ) :
428+ if not isinstance (descrpt , DescrptSeA ) :
429+ raise RuntimeError ('DipoleFittingSeA only supports DescrptSeA' )
430+ self .ntypes = descrpt .get_ntypes ()
431+ self .dim_descrpt = descrpt .get_dim_out ()
432+ args = ClassArg ()\
433+ .add ('neuron' , list , default = [120 ,120 ,120 ], alias = 'n_neuron' )\
434+ .add ('resnet_dt' , bool , default = True )\
435+ .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'dipole_type' )\
436+ .add ('seed' , int )
437+ class_data = args .parse (jdata )
438+ self .n_neuron = class_data ['neuron' ]
439+ self .resnet_dt = class_data ['resnet_dt' ]
440+ self .sel_type = class_data ['sel_type' ]
441+ self .seed = class_data ['seed' ]
442+ self .dim_rot_mat_1 = descrpt .get_dim_rot_mat_1 ()
443+ self .dim_rot_mat = self .dim_rot_mat_1 * 3
444+ self .useBN = False
445+
446+ def get_sel_type (self ):
447+ return self .sel_type
448+
449+ def build (self ,
450+ input_d ,
451+ rot_mat ,
452+ natoms ,
453+ reuse = None ,
454+ suffix = '' ) :
455+ start_index = 0
456+ inputs = tf .reshape (input_d , [- 1 , self .dim_descrpt * natoms [0 ]])
457+ rot_mat = tf .reshape (rot_mat , [- 1 , self .dim_rot_mat * natoms [0 ]])
458+ shape = inputs .get_shape ().as_list ()
459+
460+ count = 0
461+ for type_i in range (self .ntypes ):
462+ # cut-out inputs
463+ inputs_i = tf .slice (inputs ,
464+ [ 0 , start_index * self .dim_descrpt ],
465+ [- 1 , natoms [2 + type_i ]* self .dim_descrpt ] )
466+ inputs_i = tf .reshape (inputs_i , [- 1 , self .dim_descrpt ])
467+ rot_mat_i = tf .slice (rot_mat ,
468+ [ 0 , start_index * self .dim_rot_mat ],
469+ [- 1 , natoms [2 + type_i ]* self .dim_rot_mat ] )
470+ rot_mat_i = tf .reshape (rot_mat_i , [- 1 , self .dim_rot_mat_1 , 3 ])
471+ start_index += natoms [2 + type_i ]
472+ if not type_i in self .sel_type :
473+ continue
474+ layer = inputs_i
475+ for ii in range (0 ,len (self .n_neuron )) :
476+ if ii >= 1 and self .n_neuron [ii ] == self .n_neuron [ii - 1 ] :
477+ layer += one_layer (layer , self .n_neuron [ii ], name = 'layer_' + str (ii )+ '_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed , use_timestep = self .resnet_dt )
478+ else :
479+ layer = one_layer (layer , self .n_neuron [ii ], name = 'layer_' + str (ii )+ '_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
480+ # (nframes x natoms) x naxis
481+ final_layer = one_layer (layer , self .dim_rot_mat_1 , activation_fn = None , name = 'final_layer_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
482+ # (nframes x natoms) x 1 * naxis
483+ final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ] * natoms [2 + type_i ], 1 , self .dim_rot_mat_1 ])
484+ # (nframes x natoms) x 1 x 3(coord)
485+ final_layer = tf .matmul (final_layer , rot_mat_i )
486+ # nframes x natoms x 3
487+ final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ], 3 ])
488+
489+ # concat the results
490+ if count == 0 :
491+ outs = final_layer
492+ else :
493+ outs = tf .concat ([outs , final_layer ], axis = 1 )
494+ count += 1
495+
496+ return tf .reshape (outs , [- 1 ])
0 commit comments