1313from deepmd .utils .tabulate import DPTabulate
1414from deepmd .utils .type_embed import embed_atom_type
1515from deepmd .utils .sess import run_sess
16- from deepmd .utils .graph import load_graph_def , get_tensor_by_name_from_graph
16+ from deepmd .utils .graph import load_graph_def , get_tensor_by_name_from_graph , get_embedding_net_variables
1717from .descriptor import Descriptor
1818
1919class DescrptSeA (Descriptor ):
@@ -433,10 +433,10 @@ def build (self,
433433 tf .summary .histogram ('nlist' , self .nlist )
434434
435435 self .descrpt_reshape = tf .reshape (self .descrpt , [- 1 , self .ndescrpt ])
436- self .descrpt_reshape = tf .identity (self .descrpt_reshape , name = 'o_rmat' )
437- self .descrpt_deriv = tf .identity (self .descrpt_deriv , name = 'o_rmat_deriv' )
438- self .rij = tf .identity (self .rij , name = 'o_rij' )
439- self .nlist = tf .identity (self .nlist , name = 'o_nlist' )
436+ self .descrpt_reshape = tf .identity (self .descrpt_reshape , name = 'o_rmat' + suffix )
437+ self .descrpt_deriv = tf .identity (self .descrpt_deriv , name = 'o_rmat_deriv' + suffix )
438+ self .rij = tf .identity (self .rij , name = 'o_rij' + suffix )
439+ self .nlist = tf .identity (self .nlist , name = 'o_nlist' + suffix )
440440
441441 self .dout , self .qmat = self ._pass_filter (self .descrpt_reshape ,
442442 atype ,
@@ -456,6 +456,21 @@ def get_rot_mat(self) -> tf.Tensor:
456456 """
457457 return self .qmat
458458
459+ def get_tensor_names (self , suffix : str = "" ) -> Tuple [str ]:
460+ """Get names of tensors.
461+
462+ Parameters
463+ ----------
464+ suffix : str
465+ The suffix of the scope
466+
467+ Returns
468+ -------
469+ Tuple[str]
470+ Names of tensors
471+ """
472+ return (f'o_rmat{ suffix } :0' , f'o_rmat_deriv{ suffix } :0' , f'o_rij{ suffix } :0' , f'o_nlist{ suffix } :0' )
473+
459474 def pass_tensors_from_frz_model (self ,
460475 descrpt_reshape : tf .Tensor ,
461476 descrpt_deriv : tf .Tensor ,
@@ -481,60 +496,21 @@ def pass_tensors_from_frz_model(self,
481496 self .descrpt_deriv = descrpt_deriv
482497 self .descrpt_reshape = descrpt_reshape
483498
484- def get_feed_dict (self ,
485- coord_ ,
486- atype_ ,
487- natoms ,
488- box ,
489- mesh ):
490- """
491- generate the deed_dict for current descriptor
492-
493- Parameters
494- ----------
495- coord_
496- The coordinate of atoms
497- atype_
498- The type of atoms
499- natoms
500- The number of atoms. This tensor has the length of Ntypes + 2
501- natoms[0]: number of local atoms
502- natoms[1]: total number of atoms held by this processor
503- natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
504- box
505- The box. Can be generated by deepmd.model.make_stat_input
506- mesh
507- For historical reasons, only the length of the Tensor matters.
508- if size of mesh == 6, pbc is assumed.
509- if size of mesh == 0, no-pbc is assumed.
510-
511- Returns
512- -------
513- feed_dict
514- The output feed_dict of current descriptor
515- """
516- feed_dict = {
517- 't_coord:0' :coord_ ,
518- 't_type:0' :atype_ ,
519- 't_natoms:0' :natoms ,
520- 't_box:0' :box ,
521- 't_mesh:0' :mesh
522- }
523- return feed_dict
524-
525-
526499 def init_variables (self ,
527- embedding_net_variables : dict
500+ model_file : str ,
501+ suffix : str = "" ,
528502 ) -> None :
529503 """
530504 Init the embedding net variables with the given dict
531505
532506 Parameters
533507 ----------
534- embedding_net_variables
535- The input dict which stores the embedding net variables
508+ model_file : str
509+ The input frozen model file
510+ suffix : str, optional
511+ The suffix of the scope
536512 """
537- self .embedding_net_variables = embedding_net_variables
513+ self .embedding_net_variables = get_embedding_net_variables ( model_file , suffix = suffix )
538514
539515
540516 def prod_force_virial (self ,
0 commit comments