Commit 43fd8bf

Update transform.py
1 parent 0a9646c commit 43fd8bf

1 file changed: +68 additions, -18 deletions

source/train/transform.py

Lines changed: 68 additions & 18 deletions
@@ -1,6 +1,21 @@
 from deepmd.env import tf
 import re
 import numpy as np
+
+def convertNumber(number):
+    binary = bin(number).replace("0b", "").zfill(16)
+    sign = int(binary[0]) * (-2) + 1
+    exp = int(binary[1:6], 2)
+    frac = (int(binary[6:], 2) + 2 ** 10) * (2 ** -10)
+    return sign * (2 ** (exp - 15)) * frac
+
+
+def convertMatrix(matrix, shape):
+    matrix = matrix.flatten()
+    tmp = np.array([convertNumber(matrix[i]) for i in range(len(matrix))])
+    return tmp.reshape(shape)
+
+
 def transform(args):
     raw_graph = load_graph(args.raw_model)
     old_graph = load_graph(args.old_model)
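
Note: the convertNumber helper added above decodes a raw IEEE 754 half-precision bit pattern (1 sign bit, 5 exponent bits with bias 15, 10 mantissa bits with an implicit leading 1) into a Python float, and convertMatrix applies it element-wise before restoring the tensor shape. As a minimal sanity check (not part of the commit, and only covering normalized values, since the helper does not special-case zero, subnormals, or inf/NaN), its output can be compared with NumPy's bit reinterpretation:

    import numpy as np

    def convertNumber(number):
        # copy of the helper introduced in this commit
        binary = bin(number).replace("0b", "").zfill(16)
        sign = int(binary[0]) * (-2) + 1                     # bit 0: sign
        exp = int(binary[1:6], 2)                            # bits 1-5: biased exponent
        frac = (int(binary[6:], 2) + 2 ** 10) * (2 ** -10)   # bits 6-15: mantissa, implicit leading 1
        return sign * (2 ** (exp - 15)) * frac

    bits = [0x3C00, 0xC000, 0x3555]                          # float16 patterns for 1.0, -2.0, ~1/3
    print([convertNumber(b) for b in bits])
    print(np.array(bits, dtype=np.uint16).view(np.float16))  # same values via direct reinterpretation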
@@ -34,31 +49,66 @@ def transform_graph(raw_graph,old_graph):
 
     for node in raw_graph_def.node:
         if node.name in raw_graph_node.keys():
+            """
             if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
                 raise RuntimeError("float16 conversions not currently supported")
+            """
 
             check_dim(raw_graph_node, old_graph_node, node.name)
+            old_graph_dtype = precision_dict[old_graph_node[node.name].dtype]
+            raw_graph_dtype = precision_dict[raw_graph_node[node.name].dtype]
+            print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name, old_graph_dtype[1],raw_graph_dtype[1]))
+
+            if raw_graph_dtype[1] == "float16":
+                if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
+                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
+                        tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
+                        tensor_value = tensor_value.astype(np.float16)
+                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
+                        node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_shape)))
 
-            if re.fullmatch("final_layer_type_\d+/bias",node.name) == None:
-                tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = precision_dict[old_graph_node[node.name].dtype][0])
-                tensor_value = tensor_value.astype(dtype=precision_dict[raw_graph_node[node.name].dtype][0])
-                node.attr["value"].tensor.tensor_content = tensor_value.tostring()
+                    else:
+                        if old_graph_dtype[1] == "float64":
+                            tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(np.float16)
+                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,tf.float16, [1])))
 
-            else:
-                if precision_dict[old_graph_node[node.name].dtype][1] == "float64":
-                    tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
-                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,precision_dict[raw_graph_node[node.name].dtype][0], [1])))
-
-                elif precision_dict[old_graph_node[node.name].dtype][1] == "float32":
-                    tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
-                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
-
-                elif precision_dict[old_graph_node[node.name].dtype][1] == "float16":
-                    tensor_value = (np.array(old_graph_node[node.name].half_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
-                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+                        elif old_graph_dtype[1] == "float32":
+                            tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(np.float16)
+                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,tf.float16, [1])))
+
+                elif old_graph_dtype[1] == "float16":
+                    tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
+                    tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_value.shape)))
 
-            print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name,precision_dict[old_graph_node[node.name].dtype][1],precision_dict[raw_graph_node[node.name].dtype][1]))
-
+            elif raw_graph_dtype[1] == "float64" or raw_graph_dtype[1] == "float32":
+                if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
+                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
+                        tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = old_graph_dtype[0])
+                        tensor_value = tensor_value.astype(dtype=raw_graph_dtype[0])
+                        node.attr["value"].tensor.tensor_content = tensor_value.tostring()
+
+                    else:
+                        if old_graph_dtype[1] == "float64":
+                            tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(raw_graph_dtype[0])
+                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
+
+                        elif old_graph_dtype[1] == "float32":
+                            tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(raw_graph_dtype[0])
+                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
+
+                elif old_graph_dtype[1] == "float16":
+                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
+                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
+                        tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
+                        tensor_value = tensor_value.astype(raw_graph_dtype[0])
+                        node.attr["value"].tensor.tensor_content = tensor_value.tostring()
+                    else:
+                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
+                        tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
+                        tensor_value = tensor_value.astype(raw_graph_dtype[0])
+                        node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], tensor_value.shape)))
+
     return raw_graph_def
 
 def check_dim(raw_graph_node, old_graph_node, node_name):
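
Note: the half_val branches go through convertMatrix rather than a plain astype because TensorFlow stores a float16 constant's values in the tensor proto's half_val field as raw 16-bit patterns (integers), not as floating-point numbers. A short round-trip sketch, assuming it runs with the helpers from this commit in scope and the same deepmd.env.tf wrapper that transform.py already imports:

    import numpy as np
    from deepmd.env import tf   # same TF wrapper imported at the top of transform.py

    values = np.array([[0.5, -1.25]], dtype=np.float16)
    proto = tf.make_tensor_proto(values, tf.float16, values.shape)
    print(list(proto.half_val))                              # raw bit patterns, not 0.5 and -1.25
    shape = [dim.size for dim in proto.tensor_shape.dim]
    print(convertMatrix(np.array(proto.half_val), shape))    # decodes back to [[0.5, -1.25]]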
