@@ -34,29 +34,34 @@ class DPTabulate():
3434 For example, `[[0, 1]]` means no interaction between type 0 and type 1.
3535 activation_function
3636 The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ACTIVATION_FN_DICT.
37+ suffix : str, optional
38+ The suffix of the scope
3739 """
3840 def __init__ (self ,
3941 model_file : str ,
4042 type_one_side : bool = False ,
4143 exclude_types : List [List [int ]] = [],
42- activation_fn : Callable [[tf .Tensor ], tf .Tensor ] = tf .nn .tanh ) -> None :
44+ activation_fn : Callable [[tf .Tensor ], tf .Tensor ] = tf .nn .tanh ,
45+ suffix : str = "" ,
46+ ) -> None :
4347 """
4448 Constructor
4549 """
4650
4751 self .model_file = model_file
4852 self .type_one_side = type_one_side
4953 self .exclude_types = exclude_types
54+ self .suffix = suffix
5055 if self .type_one_side and len (self .exclude_types ) != 0 :
51- raise RunTimeError ('"type_one_side" is not compatible with "exclude_types"' )
56+ raise RuntimeError ('"type_one_side" is not compatible with "exclude_types"' )
5257
5358 # functype
5459 if activation_fn == ACTIVATION_FN_DICT ["tanh" ]:
5560 self .functype = 1
5661 elif activation_fn == ACTIVATION_FN_DICT ["gelu" ]:
5762 self .functype = 2
5863 else :
59- raise RunTimeError ("Unknown actication function type!" )
64+ raise RuntimeError ("Unknown actication function type!" )
6065 self .activation_fn = activation_fn
6166
6267 self .graph , self .graph_def = load_graph_def (self .model_file )
@@ -72,15 +77,15 @@ def __init__(self,
7277 self .sel_a = self .graph .get_operation_by_name ('DescrptSeA' ).get_attr ('sel_a' )
7378 self .descrpt = self .graph .get_operation_by_name ('DescrptSeA' )
7479
75- self .davg = get_tensor_by_name_from_graph (self .graph , 'descrpt_attr/t_avg' )
76- self .dstd = get_tensor_by_name_from_graph (self .graph , 'descrpt_attr/t_std' )
80+ self .davg = get_tensor_by_name_from_graph (self .graph , f 'descrpt_attr{ self . suffix } /t_avg' )
81+ self .dstd = get_tensor_by_name_from_graph (self .graph , f 'descrpt_attr{ self . suffix } /t_std' )
7782 self .ntypes = get_tensor_by_name_from_graph (self .graph , 'descrpt_attr/ntypes' )
7883
7984
8085 self .rcut = self .descrpt .get_attr ('rcut_r' )
8186 self .rcut_smth = self .descrpt .get_attr ('rcut_r_smth' )
8287
83- self .embedding_net_nodes = get_embedding_net_nodes_from_graph_def (self .graph_def )
88+ self .embedding_net_nodes = get_embedding_net_nodes_from_graph_def (self .graph_def , suffix = self . suffix )
8489
8590 for tt in self .exclude_types :
8691 if (tt [0 ] not in range (self .ntypes )) or (tt [1 ] not in range (self .ntypes )):
@@ -174,14 +179,14 @@ def _get_bias(self):
174179 bias ["layer_" + str (layer )] = []
175180 if self .type_one_side :
176181 for ii in range (0 , self .ntypes ):
177- tensor_value = np .frombuffer (self .embedding_net_nodes ["filter_type_all/bias_" + str ( layer ) + "_" + str ( ii ) ].tensor_content )
178- tensor_shape = tf .TensorShape (self .embedding_net_nodes ["filter_type_all/bias_" + str ( layer ) + "_" + str ( ii ) ].tensor_shape ).as_list ()
182+ tensor_value = np .frombuffer (self .embedding_net_nodes [f "filter_type_all{ self . suffix } /bias_{ layer } _ { ii } " ].tensor_content )
183+ tensor_shape = tf .TensorShape (self .embedding_net_nodes [f "filter_type_all{ self . suffix } /bias_{ layer } _ { ii } " ].tensor_shape ).as_list ()
179184 bias ["layer_" + str (layer )].append (np .reshape (tensor_value , tensor_shape ))
180185 else :
181186 for ii in range (0 , self .ntypes * self .ntypes ):
182187 if (ii // self .ntypes , int (ii % self .ntypes )) not in self .exclude_types :
183- tensor_value = np .frombuffer (self .embedding_net_nodes ["filter_type_" + str ( ii // self .ntypes ) + " /bias_" + str ( layer ) + "_" + str ( int ( ii % self .ntypes )) ].tensor_content )
184- tensor_shape = tf .TensorShape (self .embedding_net_nodes ["filter_type_" + str ( ii // self .ntypes ) + " /bias_" + str ( layer ) + "_" + str ( int ( ii % self .ntypes )) ].tensor_shape ).as_list ()
188+ tensor_value = np .frombuffer (self .embedding_net_nodes [f "filter_type_{ ii // self .ntypes } { self . suffix } /bias_{ layer } _ { ii % self .ntypes } " ].tensor_content )
189+ tensor_shape = tf .TensorShape (self .embedding_net_nodes [f "filter_type_{ ii // self .ntypes } { self . suffix } /bias_{ layer } _ { ii % self .ntypes } " ].tensor_shape ).as_list ()
185190 bias ["layer_" + str (layer )].append (np .reshape (tensor_value , tensor_shape ))
186191 else :
187192 bias ["layer_" + str (layer )].append (np .array ([]))
@@ -193,14 +198,14 @@ def _get_matrix(self):
193198 matrix ["layer_" + str (layer )] = []
194199 if self .type_one_side :
195200 for ii in range (0 , self .ntypes ):
196- tensor_value = np .frombuffer (self .embedding_net_nodes ["filter_type_all/matrix_" + str ( layer ) + "_" + str ( ii ) ].tensor_content )
197- tensor_shape = tf .TensorShape (self .embedding_net_nodes ["filter_type_all/matrix_" + str ( layer ) + "_" + str ( ii ) ].tensor_shape ).as_list ()
201+ tensor_value = np .frombuffer (self .embedding_net_nodes [f "filter_type_all{ self . suffix } /matrix_{ layer } _ { ii } " ].tensor_content )
202+ tensor_shape = tf .TensorShape (self .embedding_net_nodes [f "filter_type_all{ self . suffix } /matrix_{ layer } _ { ii } " ].tensor_shape ).as_list ()
198203 matrix ["layer_" + str (layer )].append (np .reshape (tensor_value , tensor_shape ))
199204 else :
200205 for ii in range (0 , self .ntypes * self .ntypes ):
201206 if (ii // self .ntypes , int (ii % self .ntypes )) not in self .exclude_types :
202- tensor_value = np .frombuffer (self .embedding_net_nodes ["filter_type_" + str ( ii // self .ntypes ) + " /matrix_" + str ( layer ) + "_" + str ( int ( ii % self .ntypes )) ].tensor_content )
203- tensor_shape = tf .TensorShape (self .embedding_net_nodes ["filter_type_" + str ( ii // self .ntypes ) + " /matrix_" + str ( layer ) + "_" + str ( int ( ii % self .ntypes )) ].tensor_shape ).as_list ()
207+ tensor_value = np .frombuffer (self .embedding_net_nodes [f "filter_type_{ ii // self .ntypes } { self . suffix } /matrix_{ layer } _ { ii % self .ntypes } " ].tensor_content )
208+ tensor_shape = tf .TensorShape (self .embedding_net_nodes [f "filter_type_{ ii // self .ntypes } { self . suffix } /matrix_{ layer } _ { ii % self .ntypes } " ].tensor_shape ).as_list ()
204209 matrix ["layer_" + str (layer )].append (np .reshape (tensor_value , tensor_shape ))
205210 else :
206211 matrix ["layer_" + str (layer )].append (np .array ([]))
0 commit comments