Skip to content

Commit 904ec11

Browse files
authored
fix bug of single precision model compression (#1110)
1 parent dd3c1de commit 904ec11

File tree

4 files changed

+20
-16
lines changed

4 files changed

+20
-16
lines changed

deepmd/entrypoints/compress.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import logging
66
from typing import Optional
77

8-
from deepmd.env import tf
9-
from deepmd.common import j_loader, GLOBAL_TF_FLOAT_PRECISION
8+
from deepmd.common import j_loader
9+
from deepmd.env import tf, GLOBAL_ENER_FLOAT_PRECISION
1010
from deepmd.utils.argcheck import normalize
1111
from deepmd.utils.compat import updata_deepmd_input
1212
from deepmd.utils.errors import GraphTooLargeError, GraphWithoutTensorError
@@ -89,7 +89,7 @@ def compress(
8989

9090
tf.constant(t_min_nbor_dist,
9191
name = 'train_attr/min_nbor_dist',
92-
dtype = GLOBAL_TF_FLOAT_PRECISION)
92+
dtype = GLOBAL_ENER_FLOAT_PRECISION)
9393
jdata["model"]["compress"] = {}
9494
jdata["model"]["compress"]["type"] = 'se_e2_a'
9595
jdata["model"]["compress"]["compress"] = True

deepmd/entrypoints/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Dict, List, Optional, Any
1111

1212
from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
13-
from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_TF_FLOAT_PRECISION
13+
from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_ENER_FLOAT_PRECISION
1414
from deepmd.infer.data_modifier import DipoleChargeModifier
1515
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
1616
from deepmd.train.trainer import DPTrainer
@@ -268,10 +268,10 @@ def get_nbor_stat(jdata, rcut):
268268
# architecture to call neighbor stat
269269
tf.constant(min_nbor_dist,
270270
name = 'train_attr/min_nbor_dist',
271-
dtype = GLOBAL_TF_FLOAT_PRECISION)
271+
dtype = GLOBAL_ENER_FLOAT_PRECISION)
272272
tf.constant(max_nbor_size,
273273
name = 'train_attr/max_nbor_size',
274-
dtype = GLOBAL_TF_FLOAT_PRECISION)
274+
dtype = tf.int32)
275275
return min_nbor_dist, max_nbor_size
276276

277277
def get_sel(jdata, rcut):

deepmd/utils/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
178178
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
179179
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
180180
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
181-
tensor_value = np.frombuffer(node.tensor_content)
181+
tensor_value = np.frombuffer(node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
182182
else:
183183
tensor_value = get_tensor_by_type(node, dtype)
184184
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
@@ -266,7 +266,7 @@ def get_fitting_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
266266
dtype= tf.as_dtype(node.dtype).as_numpy_dtype
267267
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
268268
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
269-
tensor_value = np.frombuffer(node.tensor_content)
269+
tensor_value = np.frombuffer(node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
270270
else:
271271
tensor_value = get_tensor_by_type(node, dtype)
272272
fitting_net_variables[item] = np.reshape(tensor_value, tensor_shape)

deepmd/utils/tabulate.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,16 @@ def _get_bias(self):
179179
bias["layer_" + str(layer)] = []
180180
if self.type_one_side:
181181
for ii in range(0, self.ntypes):
182-
tensor_value = np.frombuffer (self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"].tensor_content)
183-
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"].tensor_shape).as_list()
182+
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"]
183+
tensor_value = np.frombuffer (node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
184+
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
184185
bias["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
185186
else:
186187
for ii in range(0, self.ntypes * self.ntypes):
187188
if (ii // self.ntypes, int(ii % self.ntypes)) not in self.exclude_types:
188-
tensor_value = np.frombuffer(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"].tensor_content)
189-
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"].tensor_shape).as_list()
189+
node = self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"]
190+
tensor_value = np.frombuffer(node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
191+
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
190192
bias["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
191193
else:
192194
bias["layer_" + str(layer)].append(np.array([]))
@@ -198,14 +200,16 @@ def _get_matrix(self):
198200
matrix["layer_" + str(layer)] = []
199201
if self.type_one_side:
200202
for ii in range(0, self.ntypes):
201-
tensor_value = np.frombuffer (self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"].tensor_content)
202-
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"].tensor_shape).as_list()
203+
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"]
204+
tensor_value = np.frombuffer (node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
205+
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
203206
matrix["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
204207
else:
205208
for ii in range(0, self.ntypes * self.ntypes):
206209
if (ii // self.ntypes, int(ii % self.ntypes)) not in self.exclude_types:
207-
tensor_value = np.frombuffer(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"].tensor_content)
208-
tensor_shape = tf.TensorShape(self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"].tensor_shape).as_list()
210+
node = self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"]
211+
tensor_value = np.frombuffer(node.tensor_content, dtype = tf.as_dtype(node.dtype).as_numpy_dtype)
212+
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
209213
matrix["layer_" + str(layer)].append(np.reshape(tensor_value, tensor_shape))
210214
else:
211215
matrix["layer_" + str(layer)].append(np.array([]))

0 commit comments

Comments
 (0)