Skip to content

Commit e2fc5e3

Browse files
authored
make compress work for hybrid descriptor composed of se_e2_a (#1094)
* make compress work for hybrid descriptor composed of se_e2_a * fix `get_embedding_net_nodes_from_graph_def` * fix lint warning
1 parent aab124f commit e2fc5e3

File tree

7 files changed

+72
-29
lines changed

7 files changed

+72
-29
lines changed

deepmd/descriptor/descriptor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ def enable_compression(self,
188188
table_extrapolate: float = 5.,
189189
table_stride_1: float = 0.01,
190190
table_stride_2: float = 0.1,
191-
check_frequency: int = -1
191+
check_frequency: int = -1,
192+
suffix: str = "",
192193
) -> None:
193194
"""
194195
Receive the statistics (distance, max_nbor_size and env_mat_range) of the
@@ -208,13 +209,15 @@ def enable_compression(self,
208209
The uniform stride of the second table
209210
check_frequency : int, default: -1
210211
The overflow check frequency
212+
suffix : str, optional
213+
The suffix of the scope
211214
212215
Notes
213216
-----
214217
This method is called by others when the descriptor supported compression.
215218
"""
216219
raise NotImplementedError(
217-
"Descriptor %s doesn't support compression!" % self.__name__)
220+
"Descriptor %s doesn't support compression!" % type(self).__name__)
218221

219222
@abstractmethod
220223
def prod_force_virial(self,

deepmd/descriptor/hybrid.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,36 @@ def prod_force_virial(self,
220220
virial += vv
221221
atom_virial += av
222222
return force, virial, atom_virial
223+
224+
def enable_compression(self,
225+
min_nbor_dist: float,
226+
model_file: str = 'frozon_model.pb',
227+
table_extrapolate: float = 5.,
228+
table_stride_1: float = 0.01,
229+
table_stride_2: float = 0.1,
230+
check_frequency: int = -1,
231+
suffix: str = ""
232+
) -> None:
233+
"""
234+
Receive the statistics (distance, max_nbor_size and env_mat_range) of the
235+
training data.
236+
237+
Parameters
238+
----------
239+
min_nbor_dist : float
240+
The nearest distance between atoms
241+
model_file : str, default: 'frozon_model.pb'
242+
The original frozen model, which will be compressed by the program
243+
table_extrapolate : float, default: 5.
244+
The scale of model extrapolation
245+
table_stride_1 : float, default: 0.01
246+
The uniform stride of the first table
247+
table_stride_2 : float, default: 0.1
248+
The uniform stride of the second table
249+
check_frequency : int, default: -1
250+
The overflow check frequency
251+
suffix : str, optional
252+
The suffix of the scope
253+
"""
254+
for idx, ii in enumerate(self.descrpt_list):
255+
ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}")

deepmd/descriptor/se_a.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ def enable_compression(self,
297297
table_extrapolate : float = 5,
298298
table_stride_1 : float = 0.01,
299299
table_stride_2 : float = 0.1,
300-
check_frequency : int = -1
300+
check_frequency : int = -1,
301+
suffix : str = "",
301302
) -> None:
302303
"""
303304
Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
@@ -316,10 +317,15 @@ def enable_compression(self,
316317
The uniform stride of the second table
317318
check_frequency
318319
The overflow check frequency
320+
suffix : str, optional
321+
The suffix of the scope
319322
"""
323+
assert (
324+
not self.filter_resnet_dt
325+
), "Model compression error: descriptor resnet_dt must be false!"
320326
self.compress = True
321327
self.table = DPTabulate(
322-
model_file, self.type_one_side, self.exclude_types, self.compress_activation_fn)
328+
model_file, self.type_one_side, self.exclude_types, self.compress_activation_fn, suffix=suffix)
323329
self.table_config = [table_extrapolate, table_stride_1, table_stride_2, check_frequency]
324330
self.lower, self.upper \
325331
= self.table.build(min_nbor_dist,
@@ -328,8 +334,8 @@ def enable_compression(self,
328334
table_stride_2)
329335

330336
graph, _ = load_graph_def(model_file)
331-
self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_avg')
332-
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_std')
337+
self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_avg' % suffix)
338+
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_std' % suffix)
333339

334340

335341

deepmd/entrypoints/compress.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,7 @@ def compress(
105105
jdata = normalize(jdata)
106106

107107
# check the descriptor info of the input file
108-
assert (
109-
jdata["model"]["descriptor"]["type"] == "se_a" or jdata["model"]["descriptor"]["type"] == "se_e2_a"
110-
), "Model compression error: descriptor type must be se_a or se_e2_a!"
111-
assert (
112-
jdata["model"]["descriptor"]["resnet_dt"] is False
113-
), "Model compression error: descriptor resnet_dt must be false!"
108+
# move to the specific Descriptor class
114109

115110
# stage 1: training or refining the model with tabulation
116111
log.info("\n\n")

deepmd/train/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,6 @@ def build (self,
325325
name = 'train_attr/max_nbor_size',
326326
dtype = GLOBAL_TF_FLOAT_PRECISION)
327327
else :
328-
assert 'rcut' in self.descrpt_param, "Error: descriptor must have attr rcut!"
329328
self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
330329
self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file']))
331330

deepmd/utils/graph.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,24 @@ def get_tensor_by_type(node,
112112
return tensor
113113

114114

115-
def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:
115+
def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = "") -> Dict:
116116
"""
117117
Get the embedding net nodes with the given tf.GraphDef object
118118
119119
Parameters
120120
----------
121121
graph_def
122122
The input tf.GraphDef object
123+
suffix : str, optional
124+
The scope suffix
123125
124126
Returns
125127
----------
126128
Dict
127129
The embedding net nodes within the given tf.GraphDef object
128130
"""
129131
embedding_net_nodes = {}
130-
embedding_net_pattern = "filter_type_\d+/matrix_\d+_\d+|filter_type_\d+/bias_\d+_\d+|filter_type_\d+/idt_\d+_\d+|filter_type_all/matrix_\d+_\d+|filter_type_all/bias_\d+_\d+|filter_type_all/idt_\d+_\d"
132+
embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d"
131133
for node in graph_def.node:
132134
if re.fullmatch(embedding_net_pattern, node.name) != None:
133135
embedding_net_nodes[node.name] = node.attr["value"].tensor

deepmd/utils/tabulate.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,34 @@ class DPTabulate():
3434
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
3535
activation_function
3636
The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ACTIVATION_FN_DICT.
37+
suffix : str, optional
38+
The suffix of the scope
3739
"""
3840
def __init__(self,
3941
model_file : str,
4042
type_one_side : bool = False,
4143
exclude_types : List[List[int]] = [],
42-
activation_fn : Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh) -> None:
44+
activation_fn : Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh,
45+
suffix : str = "",
46+
) -> None:
4347
"""
4448
Constructor
4549
"""
4650

4751
self.model_file = model_file
4852
self.type_one_side = type_one_side
4953
self.exclude_types = exclude_types
54+
self.suffix = suffix
5055
if self.type_one_side and len(self.exclude_types) != 0:
51-
raise RunTimeError('"type_one_side" is not compatible with "exclude_types"')
56+
raise RuntimeError('"type_one_side" is not compatible with "exclude_types"')
5257

5358
# functype
5459
if activation_fn == ACTIVATION_FN_DICT["tanh"]:
5560
self.functype = 1
5661
elif activation_fn == ACTIVATION_FN_DICT["gelu"]:
5762
self.functype = 2
5863
else:
59-
raise RunTimeError("Unknown actication function type!")
64+
raise RuntimeError("Unknown actication function type!")
6065
self.activation_fn = activation_fn
6166

6267
self.graph, self.graph_def = load_graph_def(self.model_file)
@@ -72,15 +77,15 @@ def __init__(self,
7277
self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a')
7378
self.descrpt = self.graph.get_operation_by_name ('DescrptSeA')
7479

75-
self.davg = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/t_avg')
76-
self.dstd = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/t_std')
80+
self.davg = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_avg')
81+
self.dstd = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_std')
7782
self.ntypes = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/ntypes')
7883

7984

8085
self.rcut = self.descrpt.get_attr('rcut_r')
8186
self.rcut_smth = self.descrpt.get_attr('rcut_r_smth')
8287

83-
self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(self.graph_def)
88+
self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(self.graph_def, suffix=self.suffix)
8489

8590
for tt in self.exclude_types:
8691
if (tt[0] not in range(self.ntypes)) or (tt[1] not in range(self.ntypes)):
@@ -174,14 +179,14 @@ def _get_bias(self):
174179
bias["layer_" + str(layer)] = []
175180
if self.type_one_side:
176181
for ii in range(0, self.ntypes):
177-
tensor_value = np.frombuffer (self.embedding_net_nodes["filter_type_all/bias_" + str(layer) + "_" + str(ii)].tensor_content)
178-
tensor_shape = tf.TensorShape(self.embedding_net_nodes["filter_type_all/bias_" + str(layer) + "_" + str(ii)].tensor_shape).as_list()
182+
tensor_value = np.frombuffer (self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"].tensor_content)
183+
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"].tensor_shape).as_list()
179184
bias["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
180185
else:
181186
for ii in range(0, self.ntypes * self.ntypes):
182187
if (ii // self.ntypes, int(ii % self.ntypes)) not in self.exclude_types:
183-
tensor_value = np.frombuffer(self.embedding_net_nodes["filter_type_" + str(ii // self.ntypes) + "/bias_" + str(layer) + "_" + str(int(ii % self.ntypes))].tensor_content)
184-
tensor_shape = tf.TensorShape(self.embedding_net_nodes["filter_type_" + str(ii // self.ntypes) + "/bias_" + str(layer) + "_" + str(int(ii % self.ntypes))].tensor_shape).as_list()
188+
tensor_value = np.frombuffer(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"].tensor_content)
189+
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"].tensor_shape).as_list()
185190
bias["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
186191
else:
187192
bias["layer_" + str(layer)].append(np.array([]))
@@ -193,14 +198,14 @@ def _get_matrix(self):
193198
matrix["layer_" + str(layer)] = []
194199
if self.type_one_side:
195200
for ii in range(0, self.ntypes):
196-
tensor_value = np.frombuffer (self.embedding_net_nodes["filter_type_all/matrix_" + str(layer) + "_" + str(ii)].tensor_content)
197-
tensor_shape = tf.TensorShape(self.embedding_net_nodes["filter_type_all/matrix_" + str(layer) + "_" + str(ii)].tensor_shape).as_list()
201+
tensor_value = np.frombuffer (self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"].tensor_content)
202+
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"].tensor_shape).as_list()
198203
matrix["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
199204
else:
200205
for ii in range(0, self.ntypes * self.ntypes):
201206
if (ii // self.ntypes, int(ii % self.ntypes)) not in self.exclude_types:
202-
tensor_value = np.frombuffer(self.embedding_net_nodes["filter_type_" + str(ii // self.ntypes) + "/matrix_" + str(layer) + "_" + str(int(ii % self.ntypes))].tensor_content)
203-
tensor_shape = tf.TensorShape(self.embedding_net_nodes["filter_type_" + str(ii // self.ntypes) + "/matrix_" + str(layer) + "_" + str(int(ii % self.ntypes))].tensor_shape).as_list()
207+
tensor_value = np.frombuffer(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"].tensor_content)
208+
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"].tensor_shape).as_list()
204209
matrix["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
205210
else:
206211
matrix["layer_" + str(layer)].append(np.array([]))

0 commit comments

Comments
 (0)