11from deepmd .env import tf
22import re
3+ import numpy as np
34def transform (args ):
4- new_graph = load_graph (args .raw_model )
5+ raw_graph = load_graph (args .raw_model )
56 old_graph = load_graph (args .old_model )
6- print ("%d ops in the raw graph\n %d ops in the old graph" % (len (new_graph .node ),len (old_graph .node )))
7- transform_node = load_data (new_graph ,old_graph )
8- for node in new_graph .node :
9- if node .name in transform_node :
10- print ("%s is passed from old graph to raw graph" % node .name )
11- node .attr ["value" ].tensor .CopyFrom (transform_node [node .name ].attr ["value" ].tensor )
7+ print ("%d ops in the raw graph\n %d ops in the old graph" % (len (raw_graph .as_graph_def ().node ),len (old_graph .as_graph_def ().node )))
8+ new_graph_def = transform_graph (raw_graph ,old_graph )
129 with tf .gfile .GFile (args .output , mode = 'wb' ) as f :
13- f .write (new_graph .SerializeToString ())
10+ f .write (new_graph_def .SerializeToString ())
1411 print ("the output model is saved in %s" % args .output )
1512
1613def load_graph (graphName ):
@@ -19,30 +16,56 @@ def load_graph(graphName):
1916 graph_def .ParseFromString (f .read ())
2017 with tf .Graph ().as_default () as graph :
2118 tf .import_graph_def (graph_def ,name = "" )
22- return graph_def
23-
24- def load_data (new_graph ,old_graph ):
25- new_graph_node = load_transform_node (new_graph )
26- old_graph_node = load_transform_node (old_graph )
27- if len (new_graph_node ) != len (old_graph_node ):
28- raise RuntimeError ("New graph and original graph has different network structure\n " )
29- for nodeName in old_graph_node .keys ():
30- check_dim (new_graph_node , old_graph_node , nodeName )
31- check_precision (new_graph_node , old_graph_node , nodeName )
32- return old_graph_node
33-
34-
35- def check_precision (new_graph_node , old_graph_node , nodeName ):
36- new_graph_precision = new_graph_node [nodeName ].attr ["value" ].tensor .dtype
37- old_graph_precision = old_graph_node [nodeName ].attr ["value" ].tensor .dtype
38- if new_graph_precision != old_graph_precision :
39- raise RuntimeError ("New graph and original graph has different" + nodeName + " precision\n " )
40-
41- def check_dim (new_graph_node , old_graph_node , nodeName ):
42- new_graph_dim = new_graph_node [nodeName ].attr ["value" ].tensor .tensor_shape
43- old_graph_dim = old_graph_node [nodeName ].attr ["value" ].tensor .tensor_shape
44- if new_graph_dim != old_graph_dim :
45- raise RuntimeError ("New graph and original graph has different" + nodeName + " dim\n " )
19+ return graph
20+
21+ def transform_graph (raw_graph ,old_graph ):
22+ precision_dict = {\
23+ 1 :(np .float32 , "float32" ),\
24+ 2 :(np .float64 , "float64" ),\
25+ 19 :(np .float16 , "float16" )\
26+ }
27+ old_graph_def = old_graph .as_graph_def ()
28+ raw_graph_def = raw_graph .as_graph_def ()
29+ raw_graph_node = load_transform_node (raw_graph_def )
30+ old_graph_node = load_transform_node (old_graph_def )
31+
32+ if len (raw_graph_node ) != len (old_graph_node ):
33+ raise RuntimeError ("raw graph and old graph has different network structure" )
34+
35+ for node in raw_graph_def .node :
36+ if node .name in raw_graph_node .keys ():
37+ if precision_dict [old_graph_node [node .name ].dtype ][1 ] == "float16" or precision_dict [raw_graph_node [node .name ].dtype ][1 ] == "float16" :
38+ raise RuntimeError ("float16 conversions not currently supported" )
39+
40+ check_dim (raw_graph_node , old_graph_node , node .name )
41+
42+ if re .fullmatch ("final_layer_type_\d+/bias" ,node .name ) == None :
43+ tensor_value = np .frombuffer (old_graph_node [node .name ].tensor_content ,dtype = precision_dict [old_graph_node [node .name ].dtype ][0 ])
44+ tensor_value = tensor_value .astype (dtype = precision_dict [raw_graph_node [node .name ].dtype ][0 ])
45+ node .attr ["value" ].tensor .tensor_content = tensor_value .tostring ()
46+
47+ else :
48+ if precision_dict [old_graph_node [node .name ].dtype ][1 ] == "float64" :
49+ tensor_value = (np .array (old_graph_node [node .name ].double_val )).astype (precision_dict [raw_graph_node [node .name ].dtype ][0 ])
50+ node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value ,precision_dict [raw_graph_node [node .name ].dtype ][0 ], [1 ])))
51+
52+ elif precision_dict [old_graph_node [node .name ].dtype ][1 ] == "float32" :
53+ tensor_value = (np .array (old_graph_node [node .name ].float_val )).astype (precision_dict [raw_graph_node [node .name ].dtype ][0 ])
54+ node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value , precision_dict [raw_graph_node [node .name ].dtype ][0 ], [1 ])))
55+
56+ elif precision_dict [old_graph_node [node .name ].dtype ][1 ] == "float16" :
57+ tensor_value = (np .array (old_graph_node [node .name ].half_val )).astype (precision_dict [raw_graph_node [node .name ].dtype ][0 ])
58+ node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value , precision_dict [raw_graph_node [node .name ].dtype ][0 ], [1 ])))
59+
60+ print ("%s is passed from old graph(%s) to raw graph(%s)" % (node .name ,precision_dict [old_graph_node [node .name ].dtype ][1 ],precision_dict [raw_graph_node [node .name ].dtype ][1 ]))
61+
62+ return raw_graph_def
63+
64+ def check_dim (raw_graph_node , old_graph_node , node_name ):
65+ raw_graph_dim = raw_graph_node [node_name ].tensor_shape
66+ old_graph_dim = old_graph_node [node_name ].tensor_shape
67+ if raw_graph_dim != old_graph_dim :
68+ raise RuntimeError ("old graph and raw graph has different" + node_name + " dim" )
4669
4770
4871def load_transform_node (graph ):
@@ -59,5 +82,5 @@ def load_transform_node(graph):
5982 "
6083 for node in graph .node :
6184 if re .fullmatch (transform_node_pattern ,node .name ) != None :
62- transform_node [node .name ] = node
85+ transform_node [node .name ] = node . attr [ "value" ]. tensor
6386 return transform_node
0 commit comments