Skip to content

Commit ed20998

Browse files
authored
use TF's built-in method to get numpy dtype (#1035)
* use TF's built-in method to get numpy dtype I got a way to get the numpy type from a int. Take an example ```py >>> tf.dtypes.as_dtype(19).as_numpy_dtype <class 'numpy.float16'> ``` `PRECISION_MAPPING` is not used any more, as it's actually not a public API. By the way, it also supports `str` ```py >>> tf.dtypes.as_dtype("float16") tf.float16 ``` * sadly only `tf.as_dtype` is supported in TF 1.8
1 parent 602760e commit ed20998

File tree

3 files changed

+4
-12
lines changed

3 files changed

+4
-12
lines changed

deepmd/common.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,6 @@
4242
"float64": tf.float64,
4343
}
4444

45-
PRECISION_MAPPING: Dict[int, type] = {
46-
1: np.float32,
47-
2: np.float64,
48-
19: np.float16,
49-
}
50-
5145

5246
def gelu(x: tf.Tensor) -> tf.Tensor:
5347
"""Gaussian Error Linear Unit.

deepmd/entrypoints/transfer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import Dict, Optional, Sequence, Tuple
44
from deepmd.env import tf
5-
from deepmd.common import PRECISION_MAPPING
65
import re
76
import numpy as np
87
import logging
@@ -121,8 +120,8 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
121120

122121
check_dim(raw_graph_node, old_graph_node, node.name)
123122
tensor_shape = [dim.size for dim in raw_node.tensor_shape.dim]
124-
old_graph_dtype = PRECISION_MAPPING[old_node.dtype]
125-
raw_graph_dtype = PRECISION_MAPPING[raw_node.dtype]
123+
old_graph_dtype = tf.as_dtype(old_node.dtype).as_numpy_dtype
124+
raw_graph_dtype = tf.as_dtype(raw_node.dtype).as_numpy_dtype
126125
log.info(
127126
f"{node.name} is passed from old graph({old_graph_dtype}) "
128127
f"to raw graph({raw_graph_dtype})"

deepmd/utils/graph.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
from typing import Tuple, Dict
44
from deepmd.env import tf
5-
from deepmd.common import PRECISION_MAPPING
65
from deepmd.utils.sess import run_sess
76
from deepmd.utils.errors import GraphWithoutTensorError
87

@@ -174,7 +173,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
174173
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def)
175174
for item in embedding_net_nodes:
176175
node = embedding_net_nodes[item]
177-
dtype = PRECISION_MAPPING[node.dtype]
176+
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
178177
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
179178
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
180179
tensor_value = np.frombuffer(node.tensor_content)
@@ -262,7 +261,7 @@ def get_fitting_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
262261
fitting_net_nodes = get_fitting_net_nodes_from_graph_def(graph_def)
263262
for item in fitting_net_nodes:
264263
node = fitting_net_nodes[item]
265-
dtype= PRECISION_MAPPING[node.dtype]
264+
dtype= tf.as_dtype(node.dtype).as_numpy_dtype
266265
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
267266
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
268267
tensor_value = np.frombuffer(node.tensor_content)

0 commit comments

Comments
 (0)