Skip to content

Commit 2a71494

Browse files
authored
fix bug of single precision transfer (#1111)
1 parent 904ec11 commit 2a71494

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

deepmd/entrypoints/transfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
130130
if raw_graph_dtype == np.float16:
131131
if old_graph_dtype == np.float64 or old_graph_dtype == np.float32:
132132
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
133-
tensor = np.frombuffer(old_node.tensor_content).astype(raw_graph_dtype)
133+
tensor = np.frombuffer(old_node.tensor_content, dtype = raw_graph_dtype)
134134
cp_attr.from_array(tensor, tf.float16, shape = tensor_shape)
135135
else:
136136
tensor = load_tensor(old_node, old_graph_dtype, raw_graph_dtype)
@@ -143,7 +143,7 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:
143143
elif raw_graph_dtype == np.float64 or raw_graph_dtype == np.float32:
144144
if old_graph_dtype == np.float64 or old_graph_dtype == np.float32:
145145
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
146-
tensor = np.frombuffer(old_node.tensor_content).astype(raw_graph_dtype)
146+
tensor = np.frombuffer(old_node.tensor_content, dtype = raw_graph_dtype)
147147
cp_attr.from_str(tensor)
148148
else:
149149
tensor = load_tensor(old_node, old_graph_dtype, raw_graph_dtype)

0 commit comments

Comments
 (0)