@@ -49,22 +49,18 @@ def transform_graph(raw_graph,old_graph):
4949
5050 for node in raw_graph_def .node :
5151 if node .name in raw_graph_node .keys ():
52- """
53- if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
54- raise RuntimeError("float16 conversions not currently supported")
55- """
5652
5753 check_dim (raw_graph_node , old_graph_node , node .name )
54+ tensor_shape = [dim .size for dim in raw_graph_node [node .name ].tensor_shape .dim ]
5855 old_graph_dtype = precision_dict [old_graph_node [node .name ].dtype ]
5956 raw_graph_dtype = precision_dict [raw_graph_node [node .name ].dtype ]
6057 print ("%s is passed from old graph(%s) to raw graph(%s)" % (node .name , old_graph_dtype [1 ],raw_graph_dtype [1 ]))
6158
6259 if raw_graph_dtype [1 ] == "float16" :
6360 if old_graph_dtype [1 ] == "float64" or old_graph_dtype [1 ] == "float32" :
64- if re . fullmatch ( "final_layer_type_\d+/bias" , node . name ) == None :
61+ if ( len ( tensor_shape ) != 1 ) or ( tensor_shape [ 0 ] != 1 ) :
6562 tensor_value = np .frombuffer (old_graph_node [node .name ].tensor_content , dtype = old_graph_dtype [0 ])
6663 tensor_value = tensor_value .astype (np .float16 )
67- tensor_shape = [dim .size for dim in raw_graph_node [node .name ].tensor_shape .dim ]
6864 node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value , tf .float16 , tensor_shape )))
6965
7066 else :
@@ -77,13 +73,12 @@ def transform_graph(raw_graph,old_graph):
7773 node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value ,tf .float16 , [1 ])))
7874
7975 elif old_graph_dtype [1 ] == "float16" :
80- tensor_shape = [dim .size for dim in raw_graph_node [node .name ].tensor_shape .dim ]
8176 tensor_value = convertMatrix (np .array (old_graph_node [node .name ].half_val ), tensor_shape )
8277 node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value , tf .float16 , tensor_value .shape )))
8378
8479 elif raw_graph_dtype [1 ] == "float64" or raw_graph_dtype [1 ] == "float32" :
8580 if old_graph_dtype [1 ] == "float64" or old_graph_dtype [1 ] == "float32" :
86- if re . fullmatch ( "final_layer_type_\d+/bias" , node . name ) == None :
81+ if ( len ( tensor_shape ) != 1 ) or ( tensor_shape [ 0 ] != 1 ) :
8782 tensor_value = np .frombuffer (old_graph_node [node .name ].tensor_content ,dtype = old_graph_dtype [0 ])
8883 tensor_value = tensor_value .astype (dtype = raw_graph_dtype [0 ])
8984 node .attr ["value" ].tensor .tensor_content = tensor_value .tostring ()
@@ -98,13 +93,11 @@ def transform_graph(raw_graph,old_graph):
9893 node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value ,raw_graph_dtype [0 ], [1 ])))
9994
10095 elif old_graph_dtype [1 ] == "float16" :
101- if re .fullmatch ("final_layer_type_\d+/bias" , node .name ) == None :
102- tensor_shape = [dim .size for dim in raw_graph_node [node .name ].tensor_shape .dim ]
96+ if (len (tensor_shape ) != 1 ) or (tensor_shape [0 ] != 1 ):
10397 tensor_value = convertMatrix (np .array (old_graph_node [node .name ].half_val ), tensor_shape )
10498 tensor_value = tensor_value .astype (raw_graph_dtype [0 ])
10599 node .attr ["value" ].tensor .tensor_content = tensor_value .tostring ()
106100 else :
107- tensor_shape = [dim .size for dim in raw_graph_node [node .name ].tensor_shape .dim ]
108101 tensor_value = convertMatrix (np .array (old_graph_node [node .name ].half_val ), tensor_shape )
109102 tensor_value = tensor_value .astype (raw_graph_dtype [0 ])
110103 node .attr ["value" ].CopyFrom (tf .AttrValue (tensor = tf .make_tensor_proto (tensor_value ,raw_graph_dtype [0 ], tensor_value .shape )))
@@ -127,8 +120,16 @@ def load_transform_node(graph):
127120 layer_\d+_type_\d+/matrix|\
128121 layer_\d+_type_\d+/bias|\
129122 layer_\d+_type_\d+/idt|\
123+ final_layer_type_\d+/matrix|\
124+ descrpt_attr/t_avg|\
125+ descrpt_attr/t_std|\
130126 final_layer_type_\d+/bias|\
131- final_layer_type_\d+/matrix\
127+ fitting_attr/t_fparam_avg|\
128+ fitting_attr/t_fparam_istd|\
129+ fitting_attr/t_aparam_avg|\
130+ fitting_attr/t_aparam_istd|\
131+ model_attr/t_tab_info|\
132+ model_attr/t_tab_data|\
132133 "
133134 for node in graph .node :
134135 if re .fullmatch (transform_node_pattern ,node .name ) != None :
0 commit comments