Skip to content

Commit c0e0a14

Browse files
authored
Update transform.py
1 parent ddcf17c commit c0e0a14

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

source/train/transform.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,18 @@ def transform_graph(raw_graph,old_graph):
4949

5050
for node in raw_graph_def.node:
5151
if node.name in raw_graph_node.keys():
52-
"""
53-
if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
54-
raise RuntimeError("float16 conversions not currently supported")
55-
"""
5652

5753
check_dim(raw_graph_node, old_graph_node, node.name)
54+
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
5855
old_graph_dtype = precision_dict[old_graph_node[node.name].dtype]
5956
raw_graph_dtype = precision_dict[raw_graph_node[node.name].dtype]
6057
print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name, old_graph_dtype[1],raw_graph_dtype[1]))
6158

6259
if raw_graph_dtype[1] == "float16":
6360
if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
64-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
61+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
6562
tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
6663
tensor_value = tensor_value.astype(np.float16)
67-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
6864
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_shape)))
6965

7066
else:
@@ -77,34 +73,32 @@ def transform_graph(raw_graph,old_graph):
7773
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,tf.float16, [1])))
7874

7975
elif old_graph_dtype[1] == "float16":
80-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
8176
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
8277
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_value.shape)))
8378

8479
elif raw_graph_dtype[1] == "float64" or raw_graph_dtype[1] == "float32":
8580
if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
86-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
81+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
8782
tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = old_graph_dtype[0])
8883
tensor_value = tensor_value.astype(dtype=raw_graph_dtype[0])
8984
node.attr["value"].tensor.tensor_content = tensor_value.tostring()
9085

9186
else:
9287
if old_graph_dtype[1] == "float64":
9388
tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(raw_graph_dtype[0])
89+
print(node.name,tensor_value)
9490
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
9591

9692
elif old_graph_dtype[1] == "float32":
9793
tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(raw_graph_dtype[0])
9894
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
9995

10096
elif old_graph_dtype[1] == "float16":
101-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
102-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
97+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
10398
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
10499
tensor_value = tensor_value.astype(raw_graph_dtype[0])
105100
node.attr["value"].tensor.tensor_content = tensor_value.tostring()
106101
else:
107-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
108102
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
109103
tensor_value = tensor_value.astype(raw_graph_dtype[0])
110104
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], tensor_value.shape)))
@@ -127,8 +121,16 @@ def load_transform_node(graph):
127121
layer_\d+_type_\d+/matrix|\
128122
layer_\d+_type_\d+/bias|\
129123
layer_\d+_type_\d+/idt|\
124+
final_layer_type_\d+/matrix|\
125+
descrpt_attr/t_avg|\
126+
descrpt_attr/t_std|\
130127
final_layer_type_\d+/bias|\
131-
final_layer_type_\d+/matrix\
128+
fitting_attr/t_fparam_avg|\
129+
fitting_attr/t_fparam_istd|\
130+
fitting_attr/t_aparam_avg|\
131+
fitting_attr/t_aparam_istd|\
132+
model_attr/t_tab_info|\
133+
model_attr/t_tab_data|\
132134
"
133135
for node in graph.node:
134136
if re.fullmatch(transform_node_pattern,node.name) != None:

0 commit comments

Comments (0)