Skip to content

Commit d5a0f6f

Browse files
authored
Enable init-frz-model support for the original model (#1102)
* enable init-frz-model support for the original model within the dp train interface
* add init_variables method for ABC
* add doc for embedding_net method
1 parent 32ccbb5 commit d5a0f6f

File tree

4 files changed

+79
-24
lines changed

4 files changed

+79
-24
lines changed

deepmd/descriptor/descriptor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,23 @@ def get_feed_dict(self,
283283
# TODO: currently only SeA has this method, but I think the method can be
284284
# moved here as it doesn't contain anything related to a specific descriptor
285285
raise NotImplementedError
286+
287+
def init_variables(self,
                   embedding_net_variables: dict
                   ) -> None:
    """
    Init the embedding net variables with the given dict

    Parameters
    ----------
    embedding_net_variables
        The input dict which stores the embedding net variables

    Raises
    ------
    NotImplementedError
        Always raised by this implementation; the message names the
        concrete class of ``self``.

    Notes
    -----
    This method is called by others when the descriptor supports initialization from the given variables.
    """
    # TODO: currently only SeA has this method, but I think the method can be
    # moved here as it doesn't contain anything related to a specific descriptor
    raise NotImplementedError(
        "Descriptor %s doesn't support initialization from the given variables!" % type(self).__name__)

deepmd/descriptor/se_a.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def __init__ (self,
157157
self.dstd = None
158158
self.davg = None
159159
self.compress = False
160+
self.embedding_net_variables = None
160161
self.place_holders = {}
161162
nei_type = np.array([])
162163
for ii in range(self.ntypes):
@@ -521,6 +522,21 @@ def get_feed_dict(self,
521522
}
522523
return feed_dict
523524

525+
526+
def init_variables(self,
                   embedding_net_variables: dict
                   ) -> None:
    """
    Init the embedding net variables with the given dict

    Parameters
    ----------
    embedding_net_variables
        The input dict which stores the embedding net variables
    """
    # Only store the dict here; it is handed to the embedding net builder
    # as `initial_variables` when the filter network is constructed.
    self.embedding_net_variables = embedding_net_variables
538+
539+
524540
def prod_force_virial(self,
525541
atom_ener : tf.Tensor,
526542
natoms : tf.Tensor
@@ -766,7 +782,8 @@ def _filter_lower(
766782
bavg = bavg,
767783
seed = self.seed,
768784
trainable = trainable,
769-
uniform_seed = self.uniform_seed)
785+
uniform_seed = self.uniform_seed,
786+
initial_variables = self.embedding_net_variables)
770787
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
771788
else:
772789
# we can safely return the final xyz_scatter filled with zero directly

deepmd/train/trainer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from deepmd.utils.neighbor_stat import NeighborStat
2727
from deepmd.utils.sess import run_sess
2828
from deepmd.utils.type_embed import TypeEmbedNet
29-
from deepmd.utils.graph import get_tensor_by_name, get_fitting_net_variables
29+
from deepmd.utils.graph import get_tensor_by_name, get_embedding_net_variables, get_fitting_net_variables
3030

3131
from tensorflow.python.client import timeline
3232
from deepmd.env import op_module
@@ -278,7 +278,6 @@ def _init_param(self, jdata):
278278
# if init the graph with the frozen model
279279
self.frz_model = None
280280
self.model_type = None
281-
self.init_from_frz_model = False
282281

283282

284283
def build (self,
@@ -694,14 +693,17 @@ def _init_from_frz_model(self):
694693
"which is not supported by the 'dp train init-frz-model' interface. " % self.run_opt.init_frz_model
695694
) from e
696695

696+
if self.fitting_type != 'ener':
697+
raise RuntimeError("The 'dp train init-frz-model' command only supports the 'ener' type fitting net currently!")
697698
# self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch...
698699
# initialize fitting net with the given compressed frozen model
699-
if self.model_type == 'compressed_model' and self.fitting_type == 'ener':
700-
self.init_from_frz_model = True
700+
if self.model_type == 'original_model':
701+
self.descrpt.init_variables(get_embedding_net_variables(self.run_opt.init_frz_model))
702+
self.fitting.init_variables(get_fitting_net_variables(self.run_opt.init_frz_model))
703+
tf.constant("original_model", name = 'model_type', dtype = tf.string)
704+
elif self.model_type == 'compressed_model':
701705
self.frz_model = self.run_opt.init_frz_model
702706
self.fitting.init_variables(get_fitting_net_variables(self.frz_model))
703707
tf.constant("compressed_model", name = 'model_type', dtype = tf.string)
704-
elif self.fitting_type != 'ener':
705-
raise RuntimeError("The 'dp train init-frz-model' command only supports the 'ener' type fitting net currently!")
706708
else:
707-
raise RuntimeError("The 'dp train init-frz-model' command only supports the compressed model currently!")
709+
raise RuntimeError("Unknown model type %s" % self.model_type)

deepmd/utils/network.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def embedding_net(xx,
9292
bavg = 0.0,
9393
seed = None,
9494
trainable = True,
95-
uniform_seed = False):
95+
uniform_seed = False,
96+
initial_variables = None):
9697
r"""The embedding network.
9798
9899
The embedding network function :math:`\mathcal{N}` is constructed by the
@@ -141,6 +142,11 @@ def embedding_net(xx,
141142
Random seed for initializing network parameters
142143
trainable: boolean
143144
If the network is trainable
145+
uniform_seed : boolean
146+
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
147+
initial_variables : dict
148+
The input dict which stores the embedding net variables
149+
144150
145151
References
146152
----------
@@ -152,37 +158,47 @@ def embedding_net(xx,
152158
outputs_size = [input_shape[1]] + network_size
153159

154160
for ii in range(1, len(outputs_size)):
155-
w = tf.get_variable('matrix_'+str(ii)+name_suffix,
161+
w_initializer = tf.random_normal_initializer(
162+
stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]),
163+
seed = seed if (seed is None or uniform_seed) else seed + ii*3+0
164+
)
165+
b_initializer = tf.random_normal_initializer(
166+
stddev=stddev,
167+
mean = bavg,
168+
seed = seed if (seed is None or uniform_seed) else seed + 3*ii+1
169+
)
170+
if initial_variables is not None:
171+
scope = tf.get_variable_scope().name
172+
w_initializer = tf.constant_initializer(initial_variables[scope+'/matrix_'+str(ii)+name_suffix])
173+
b_initializer = tf.constant_initializer(initial_variables[scope+'/bias_'+str(ii)+name_suffix])
174+
w = tf.get_variable('matrix_'+str(ii)+name_suffix,
156175
[outputs_size[ii - 1], outputs_size[ii]],
157176
precision,
158-
tf.random_normal_initializer(
159-
stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]),
160-
seed = seed if (seed is None or uniform_seed) else seed + ii*3+0
161-
),
177+
w_initializer,
162178
trainable = trainable)
163179
variable_summaries(w, 'matrix_'+str(ii)+name_suffix)
164180

165181
b = tf.get_variable('bias_'+str(ii)+name_suffix,
166182
[1, outputs_size[ii]],
167183
precision,
168-
tf.random_normal_initializer(
169-
stddev=stddev,
170-
mean = bavg,
171-
seed = seed if (seed is None or uniform_seed) else seed + 3*ii+1
172-
),
184+
b_initializer,
173185
trainable = trainable)
174186
variable_summaries(b, 'bias_'+str(ii)+name_suffix)
175187

176188
hidden = tf.reshape(activation_fn(tf.matmul(xx, w) + b), [-1, outputs_size[ii]])
177189
if resnet_dt :
190+
idt_initializer = tf.random_normal_initializer(
191+
stddev=0.001,
192+
mean = 1.0,
193+
seed = seed if (seed is None or uniform_seed) else seed + 3*ii+2
194+
)
195+
if initial_variables is not None:
196+
scope = tf.get_variable_scope().name
197+
idt_initializer = tf.constant_initializer(initial_variables[scope+'/idt_'+str(ii)+name_suffix])
178198
idt = tf.get_variable('idt_'+str(ii)+name_suffix,
179199
[1, outputs_size[ii]],
180200
precision,
181-
tf.random_normal_initializer(
182-
stddev=0.001,
183-
mean = 1.0,
184-
seed = seed if (seed is None or uniform_seed) else seed + 3*ii+2
185-
),
201+
idt_initializer,
186202
trainable = trainable)
187203
variable_summaries(idt, 'idt_'+str(ii)+name_suffix)
188204

0 commit comments

Comments
 (0)