Skip to content

Commit 97be2f5

Browse files
njzjzdenghuilu
andauthored
support init_frz_model for hybrid descriptor (#1112)
* support init_frz_model for hybrid descriptor Refactors some methods to implement it. Also fixes some typos. * rename `graph_def` to `model_file` Co-authored-by: Denghui Lu <[email protected]> Co-authored-by: Denghui Lu <[email protected]>
1 parent 60797e0 commit 97be2f5

File tree

7 files changed

+150
-75
lines changed

7 files changed

+150
-75
lines changed

deepmd/descriptor/descriptor.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,26 +280,65 @@ def get_feed_dict(self,
280280
feed_dict : dict[str, tf.Tensor]
281281
The output feed_dict of current descriptor
282282
"""
283-
# TODO: currently only SeA has this method, but I think the method can be
284-
# moved here as it doesn't contain anything related to a specific descriptor
285-
raise NotImplementedError
283+
feed_dict = {
284+
't_coord:0' :coord_,
285+
't_type:0' :atype_,
286+
't_natoms:0' :natoms,
287+
't_box:0' :box,
288+
't_mesh:0' :mesh
289+
}
290+
return feed_dict
286291

287292
def init_variables(self,
288-
embedding_net_variables: dict
289-
) -> None:
293+
model_file: str,
294+
suffix : str = "",
295+
) -> None:
290296
"""
291297
Init the embedding net variables with the given dict
292298
293299
Parameters
294300
----------
295-
embedding_net_variables
296-
The input dict which stores the embedding net variables
301+
model_file : str
302+
The input model file
303+
suffix : str, optional
304+
The suffix of the scope
297305
298306
Notes
299307
-----
300308
This method is called by others when the descriptor supported initialization from the given variables.
301309
"""
302-
# TODO: currently only SeA has this method, but I think the method can be
303-
# moved here as it doesn't contain anything related to a specific descriptor
304310
raise NotImplementedError(
305311
"Descriptor %s doesn't support initialization from the given variables!" % type(self).__name__)
312+
313+
def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
314+
"""Get names of tensors.
315+
316+
Parameters
317+
----------
318+
suffix : str
319+
The suffix of the scope
320+
321+
Returns
322+
-------
323+
Tuple[str]
324+
Names of tensors
325+
"""
326+
raise NotImplementedError("Descriptor %s doesn't support this property!" % type(self).__name__)
327+
328+
def pass_tensors_from_frz_model(self,
329+
*tensors : tf.Tensor,
330+
) -> None:
331+
"""
332+
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
333+
334+
Parameters
335+
----------
336+
*tensors : tf.Tensor
337+
passed tensors
338+
339+
Notes
340+
-----
341+
The number of parameters in the method must be equal to the numbers of returns in
342+
:meth:`get_tensor_names`.
343+
"""
344+
raise NotImplementedError("Descriptor %s doesn't support this method!" % type(self).__name__)

deepmd/descriptor/hybrid.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,55 @@ def enable_compression(self,
253253
"""
254254
for idx, ii in enumerate(self.descrpt_list):
255255
ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}")
256+
257+
def init_variables(self,
258+
model_file : str,
259+
suffix : str = "",
260+
) -> None:
261+
"""
262+
Init the embedding net variables with the given dict
263+
264+
Parameters
265+
----------
266+
model_file : str
267+
The input frozen model file
268+
suffix : str, optional
269+
The suffix of the scope
270+
"""
271+
for idx, ii in enumerate(self.descrpt_list):
272+
ii.init_variables(model_file, suffix=f"{suffix}_{idx}")
273+
274+
def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
275+
"""Get names of tensors.
276+
277+
Parameters
278+
----------
279+
suffix : str
280+
The suffix of the scope
281+
282+
Returns
283+
-------
284+
Tuple[str]
285+
Names of tensors
286+
"""
287+
tensor_names = []
288+
for idx, ii in enumerate(self.descrpt_list):
289+
tensor_names.extend(ii.get_tensor_names(suffix=f"{suffix}_{idx}"))
290+
return tuple(tensor_names)
291+
292+
def pass_tensors_from_frz_model(self,
293+
*tensors : tf.Tensor,
294+
) -> None:
295+
"""
296+
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def
297+
298+
Parameters
299+
----------
300+
*tensors : tf.Tensor
301+
passed tensors
302+
"""
303+
jj = 0
304+
for ii in self.descrpt_list:
305+
n_tensors = len(ii.get_tensor_names())
306+
ii.pass_tensors_from_frz_model(*tensors[jj:jj+n_tensors])
307+
jj += n_tensors

deepmd/descriptor/se_a.py

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from deepmd.utils.tabulate import DPTabulate
1414
from deepmd.utils.type_embed import embed_atom_type
1515
from deepmd.utils.sess import run_sess
16-
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
16+
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
1717
from .descriptor import Descriptor
1818

1919
class DescrptSeA (Descriptor):
@@ -433,10 +433,10 @@ def build (self,
433433
tf.summary.histogram('nlist', self.nlist)
434434

435435
self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
436-
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat')
437-
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv')
438-
self.rij = tf.identity(self.rij, name = 'o_rij')
439-
self.nlist = tf.identity(self.nlist, name = 'o_nlist')
436+
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix)
437+
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix)
438+
self.rij = tf.identity(self.rij, name = 'o_rij' + suffix)
439+
self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix)
440440

441441
self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
442442
atype,
@@ -456,6 +456,21 @@ def get_rot_mat(self) -> tf.Tensor:
456456
"""
457457
return self.qmat
458458

459+
def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
460+
"""Get names of tensors.
461+
462+
Parameters
463+
----------
464+
suffix : str
465+
The suffix of the scope
466+
467+
Returns
468+
-------
469+
Tuple[str]
470+
Names of tensors
471+
"""
472+
return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')
473+
459474
def pass_tensors_from_frz_model(self,
460475
descrpt_reshape : tf.Tensor,
461476
descrpt_deriv : tf.Tensor,
@@ -481,60 +496,21 @@ def pass_tensors_from_frz_model(self,
481496
self.descrpt_deriv = descrpt_deriv
482497
self.descrpt_reshape = descrpt_reshape
483498

484-
def get_feed_dict(self,
485-
coord_,
486-
atype_,
487-
natoms,
488-
box,
489-
mesh):
490-
"""
491-
generate the deed_dict for current descriptor
492-
493-
Parameters
494-
----------
495-
coord_
496-
The coordinate of atoms
497-
atype_
498-
The type of atoms
499-
natoms
500-
The number of atoms. This tensor has the length of Ntypes + 2
501-
natoms[0]: number of local atoms
502-
natoms[1]: total number of atoms held by this processor
503-
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
504-
box
505-
The box. Can be generated by deepmd.model.make_stat_input
506-
mesh
507-
For historical reasons, only the length of the Tensor matters.
508-
if size of mesh == 6, pbc is assumed.
509-
if size of mesh == 0, no-pbc is assumed.
510-
511-
Returns
512-
-------
513-
feed_dict
514-
The output feed_dict of current descriptor
515-
"""
516-
feed_dict = {
517-
't_coord:0' :coord_,
518-
't_type:0' :atype_,
519-
't_natoms:0' :natoms,
520-
't_box:0' :box,
521-
't_mesh:0' :mesh
522-
}
523-
return feed_dict
524-
525-
526499
def init_variables(self,
527-
embedding_net_variables: dict
500+
model_file : str,
501+
suffix : str = "",
528502
) -> None:
529503
"""
530504
Init the embedding net variables with the given dict
531505
532506
Parameters
533507
----------
534-
embedding_net_variables
535-
The input dict which stores the embedding net variables
508+
model_file : str
509+
The input frozen model file
510+
suffix : str, optional
511+
The suffix of the scope
536512
"""
537-
self.embedding_net_variables = embedding_net_variables
513+
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)
538514

539515

540516
def prod_force_virial(self,

deepmd/model/ener.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ def build (self,
173173
name = 'descrpt_attr/ntypes',
174174
dtype = tf.int32)
175175
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
176-
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
177-
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
176+
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
177+
imported_tensors \
178178
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
179-
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
179+
dout = imported_tensors[-1]
180+
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])
180181

181182

182183
if self.srtab is not None :

deepmd/model/tensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from deepmd.env import tf
55
from deepmd.common import ClassArg
6-
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION
6+
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION
77
from deepmd.env import op_module
88
from deepmd.utils.graph import load_graph_def
99
from .model_stat import make_stat_input, merge_sys_stat
@@ -138,10 +138,11 @@ def build (self,
138138
name = 'descrpt_attr/ntypes',
139139
dtype = tf.int32)
140140
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
141-
return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0']
142-
descrpt_reshape, descrpt_deriv, rij, nlist, dout \
141+
return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
142+
imported_tensors \
143143
= self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements)
144-
self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist)
144+
dout = imported_tensors[-1]
145+
self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])
145146

146147
rot_mat = self.descrpt.get_rot_mat()
147148
rot_mat = tf.identity(rot_mat, name = 'o_rot_mat'+suffix)

deepmd/train/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def _init_from_frz_model(self):
691691
# self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch...
692692
# initialize fitting net with the given compressed frozen model
693693
if self.model_type == 'original_model':
694-
self.descrpt.init_variables(get_embedding_net_variables(self.run_opt.init_frz_model))
694+
self.descrpt.init_variables(self.run_opt.init_frz_model)
695695
self.fitting.init_variables(get_fitting_net_variables(self.run_opt.init_frz_model))
696696
tf.constant("original_model", name = 'model_type', dtype = tf.string)
697697
elif self.model_type == 'compressed_model':

deepmd/utils/graph.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def get_tensor_by_type(node,
108108
elif data_type == np.float32:
109109
tensor = np.array(node.float_val)
110110
else:
111-
raise RunTimeError('model compression does not support the half precision')
111+
raise RuntimeError('model compression does not support the half precision')
112112
return tensor
113113

114114

@@ -139,40 +139,44 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
139139
return embedding_net_nodes
140140

141141

142-
def get_embedding_net_nodes(model_file: str) -> Dict:
142+
def get_embedding_net_nodes(model_file: str, suffix: str = "") -> Dict:
143143
"""
144144
Get the embedding net nodes with the given frozen model(model_file)
145145
146146
Parameters
147147
----------
148148
model_file
149149
The input frozen model path
150+
suffix : str, optional
151+
The suffix of the scope
150152
151153
Returns
152154
----------
153155
Dict
154156
The embedding net nodes with the given frozen model
155157
"""
156158
_, graph_def = load_graph_def(model_file)
157-
return get_embedding_net_nodes_from_graph_def(graph_def)
159+
return get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)
158160

159161

160-
def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
162+
def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef, suffix: str = "") -> Dict:
161163
"""
162164
Get the embedding net variables with the given tf.GraphDef object
163165
164166
Parameters
165167
----------
166168
graph_def
167169
The input tf.GraphDef object
170+
suffix : str, optional
171+
The suffix of the scope
168172
169173
Returns
170174
----------
171175
Dict
172176
The embedding net variables within the given tf.GraphDef object
173177
"""
174178
embedding_net_variables = {}
175-
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def)
179+
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix)
176180
for item in embedding_net_nodes:
177181
node = embedding_net_nodes[item]
178182
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
@@ -184,22 +188,24 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
184188
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
185189
return embedding_net_variables
186190

187-
def get_embedding_net_variables(model_file : str) -> Dict:
191+
def get_embedding_net_variables(model_file : str, suffix: str = "") -> Dict:
188192
"""
189193
Get the embedding net variables with the given frozen model(model_file)
190194
191195
Parameters
192196
----------
193197
model_file
194198
The input frozen model path
199+
suffix : str, optional
200+
The suffix of the scope
195201
196202
Returns
197203
----------
198204
Dict
199205
The embedding net variables within the given frozen model
200206
"""
201207
_, graph_def = load_graph_def(model_file)
202-
return get_embedding_net_variables_from_graph_def(graph_def)
208+
return get_embedding_net_variables_from_graph_def(graph_def, suffix=suffix)
203209

204210

205211
def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:

0 commit comments

Comments
 (0)