
Commit 8f6dea1

Update transform.py
1 parent 288f61d commit 8f6dea1

File tree

1 file changed: +56 -33 lines changed

source/train/transform.py

Lines changed: 56 additions & 33 deletions
@@ -1,16 +1,13 @@
 from deepmd.env import tf
 import re
+import numpy as np
 def transform(args):
-    new_graph = load_graph(args.raw_model)
+    raw_graph = load_graph(args.raw_model)
     old_graph = load_graph(args.old_model)
-    print("%d ops in the raw graph\n%d ops in the old graph" %(len(new_graph.node),len(old_graph.node)))
-    transform_node = load_data(new_graph,old_graph)
-    for node in new_graph.node:
-        if node.name in transform_node:
-            print("%s is passed from old graph to raw graph" % node.name)
-            node.attr["value"].tensor.CopyFrom(transform_node[node.name].attr["value"].tensor)
+    print("%d ops in the raw graph\n%d ops in the old graph" %(len(raw_graph.as_graph_def().node),len(old_graph.as_graph_def().node)))
+    new_graph_def = transform_graph(raw_graph,old_graph)
     with tf.gfile.GFile(args.output, mode='wb') as f:
-        f.write(new_graph.SerializeToString())
+        f.write(new_graph_def.SerializeToString())
     print("the output model is saved in %s" % args.output)

 def load_graph(graphName):
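For context, transform() reads three attributes from its argument: raw_model, old_model, and output. A minimal driver sketch for the script; only those attribute names come from the commit, the flag spellings are assumptions:

# Hypothetical CLI wrapper; flag names are assumptions, not part of this commit.
import argparse

def main():
    parser = argparse.ArgumentParser(
        description="copy trained parameters from an old model into a raw graph")
    parser.add_argument("--raw-model", dest="raw_model", required=True,
                        help="graph that receives the parameters")
    parser.add_argument("--old-model", dest="old_model", required=True,
                        help="graph that holds the trained parameters")
    parser.add_argument("--output", dest="output", required=True,
                        help="where to write the merged model")
    transform(parser.parse_args())

if __name__ == "__main__":
    main()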
@@ -19,30 +16,56 @@ def load_graph(graphName):
         graph_def.ParseFromString(f.read())
     with tf.Graph().as_default() as graph:
         tf.import_graph_def(graph_def,name = "")
-    return graph_def
-
-def load_data(new_graph,old_graph):
-    new_graph_node = load_transform_node(new_graph)
-    old_graph_node = load_transform_node(old_graph)
-    if len(new_graph_node) != len(old_graph_node):
-        raise RuntimeError("New graph and original graph has different network structure\n")
-    for nodeName in old_graph_node.keys():
-        check_dim(new_graph_node, old_graph_node, nodeName)
-        check_precision(new_graph_node, old_graph_node, nodeName)
-    return old_graph_node
-
-
-def check_precision(new_graph_node, old_graph_node, nodeName):
-    new_graph_precision = new_graph_node[nodeName].attr["value"].tensor.dtype
-    old_graph_precision = old_graph_node[nodeName].attr["value"].tensor.dtype
-    if new_graph_precision != old_graph_precision:
-        raise RuntimeError("New graph and original graph has different"+nodeName+" precision\n")
-
-def check_dim(new_graph_node, old_graph_node, nodeName):
-    new_graph_dim = new_graph_node[nodeName].attr["value"].tensor.tensor_shape
-    old_graph_dim = old_graph_node[nodeName].attr["value"].tensor.tensor_shape
-    if new_graph_dim != old_graph_dim:
-        raise RuntimeError("New graph and original graph has different"+nodeName+" dim\n")
+    return graph
+
+def transform_graph(raw_graph,old_graph):
+    precision_dict = {\
+    1:(np.float32, "float32"),\
+    2:(np.float64, "float64"),\
+    19:(np.float16, "float16")\
+    }
+    old_graph_def = old_graph.as_graph_def()
+    raw_graph_def = raw_graph.as_graph_def()
+    raw_graph_node = load_transform_node(raw_graph_def)
+    old_graph_node = load_transform_node(old_graph_def)
+
+    if len(raw_graph_node) != len(old_graph_node):
+        raise RuntimeError("raw graph and old graph has different network structure")
+
+    for node in raw_graph_def.node:
+        if node.name in raw_graph_node.keys():
+            if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
+                raise RuntimeError("float16 conversions not currently supported")
+
+            check_dim(raw_graph_node, old_graph_node, node.name)
+
+            if re.fullmatch("final_layer_type_\d+/bias",node.name) == None:
+                tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = precision_dict[old_graph_node[node.name].dtype][0])
+                tensor_value = tensor_value.astype(dtype=precision_dict[raw_graph_node[node.name].dtype][0])
+                node.attr["value"].tensor.tensor_content = tensor_value.tostring()
+
+            else:
+                if precision_dict[old_graph_node[node.name].dtype][1] == "float64":
+                    tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+                elif precision_dict[old_graph_node[node.name].dtype][1] == "float32":
+                    tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+                elif precision_dict[old_graph_node[node.name].dtype][1] == "float16":
+                    tensor_value = (np.array(old_graph_node[node.name].half_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+            print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name,precision_dict[old_graph_node[node.name].dtype][1],precision_dict[raw_graph_node[node.name].dtype][1]))
+
+    return raw_graph_def
+
+def check_dim(raw_graph_node, old_graph_node, node_name):
+    raw_graph_dim = raw_graph_node[node_name].tensor_shape
+    old_graph_dim = old_graph_node[node_name].tensor_shape
+    if raw_graph_dim != old_graph_dim:
+        raise RuntimeError("old graph and raw graph has different"+node_name+" dim")


 def load_transform_node(graph):
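The integer keys in precision_dict are TensorFlow DataType enum values (DT_FLOAT = 1, DT_DOUBLE = 2, DT_HALF = 19), i.e. the dtype field of a TensorProto. A minimal stand-alone sketch of the tensor_content round-trip performed above, assuming the old tensor is packed float64 and the raw graph wants float32:

# Sketch of the dtype conversion in transform_graph; the input values are made up.
import numpy as np

packed = np.arange(4, dtype=np.float64).tobytes()  # plays the role of TensorProto.tensor_content
values = np.frombuffer(packed, dtype=np.float64)   # decode with the old graph's precision
converted = values.astype(np.float32)              # re-encode with the raw graph's precision
new_content = converted.tobytes()                  # tostring() in the commit is the older alias of tobytes()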
@@ -59,5 +82,5 @@ def load_transform_node(graph):
     "
     for node in graph.node:
         if re.fullmatch(transform_node_pattern,node.name) != None:
-            transform_node[node.name] = node
+            transform_node[node.name] = node.attr["value"].tensor
     return transform_node
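The nodes matching final_layer_type_\d+/bias take the other branch because their value lives in a typed repeated field (double_val, float_val, or half_val) rather than in tensor_content, so the TensorProto has to be rebuilt. A minimal sketch of that rebuild with the TF 1.x API, assuming a one-element float64 bias converted to float32:

# Rebuilding a shape-[1] TensorProto at a new precision (TF 1.x; values are made up).
import numpy as np
import tensorflow as tf

old_value = np.array([0.5], dtype=np.float64)  # stand-in for TensorProto.double_val
new_value = old_value.astype(np.float32)       # target precision of the raw graph
attr = tf.AttrValue(tensor=tf.make_tensor_proto(new_value, np.float32, [1]))
# node.attr["value"].CopyFrom(attr) would then install it on the node, as in the commit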
