Skip to content

Commit bc7b7ae

Browse files
authored
Merge pull request #522 from nbcsm/yolo
fix conv_convert_inputs bug
2 parents e59757b + 64596bc commit bc7b7ae

File tree

6 files changed

+11
-16
lines changed

6 files changed

+11
-16
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
251251
# convert model to onnx
252252
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
253253
shape_override=shape_override, input_names=inputs.keys())
254+
onnx_graph = optimizer.optimize_graph(onnx_graph)
254255
model_proto = onnx_graph.make_model("converted from tf2onnx")
255-
model_proto = optimizer.optimize_graph(onnx_graph).make_model("optimized")
256256
logger.info("To_ONNX, OK")
257257
if onnx_file:
258258
self.create_onnx_file(name, model_proto, inputs, onnx_file)

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,10 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
5555
# transpose input if needed, no need to record shapes on input
5656
for idx in input_indices:
5757
parent = node.inputs[idx]
58-
if node.inputs[idx].is_const():
59-
# if input is a constant, transpose that one
60-
if not parent.data_format:
61-
val = parent.get_tensor_value(as_list=False)
62-
parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
58+
if node.inputs[idx].is_const() and len(ctx.find_output_consumers(node.input[1])) == 1:
59+
# if input is a constant, transpose that one if we are the only consumer
60+
val = parent.get_tensor_value(as_list=False)
61+
parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
6362
else:
6463
# if input comes from a op, insert transpose op
6564
input_name = node.input[idx]
@@ -70,33 +69,27 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
7069
if shape is not None:
7170
new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
7271
ctx.set_shape(transpose.output[0], new_shape)
73-
parent.data_format = "NCHW"
7472

7573
# kernel must be transposed
7674
if with_kernel:
7775
parent = node.inputs[1]
7876
need_transpose = True
7977
if node.inputs[1].is_const():
8078
# kernel is const - transpose the const if we are the only consumer of const
81-
# TODO: maybe we should make a copy of the const, or look at the other consumers
82-
# if they'd want a transpose as well.
8379
consumers = ctx.find_output_consumers(node.input[1])
8480
if len(consumers) == 1:
8581
val = parent.get_tensor_value(as_list=False)
8682
val = val.transpose(constants.HWCN_TO_NCHW)
8783
parent.set_tensor_value(val)
88-
parent.data_format = "NCHW"
8984
need_transpose = False
9085

9186
if need_transpose:
9287
input_name = node.input[1]
9388
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
9489
transpose.set_attr("perm", constants.HWCN_TO_NCHW)
9590
transpose.skip_conversion = True
96-
ctx.copy_shape(input_name, transpose.output[0])
9791
new_shape = spatial_map(ctx.get_shape(input_name), constants.HWCN_TO_NCHW)
9892
ctx.set_shape(transpose.output[0], new_shape)
99-
parent.data_format = "NCHW"
10093

10194
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
10295
if new_kernel_shape:
@@ -129,7 +122,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
129122
ctx.set_shape(transpose.output[0], output_shape)
130123
# Transpose TF NHWC shape back to NCHW shape for current ONNX conv node output
131124
ctx.set_shape(output_name, spatial_map(output_shape, constants.NHWC_TO_NCHW))
132-
node.data_format = "NCHW"
125+
node.data_format = "NCHW"
133126

134127

135128
def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):

tf2onnx/optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def optimize_graph(graph):
5050
diff = copy.deepcopy(after)
5151
diff.subtract(before)
5252
diff = ["{} {} ({}->{})".format(k, str(v) if v < 0 else '+' + str(v), before.get(k, 0), after.get(k, 0))
53-
for k, v in diff.most_common() if v != 0]
53+
for k, v in sorted(diff.items()) if v != 0]
5454
logger.info("After optimization: %s", ', '.join(diff) if diff else "no change")
5555

5656
return graph

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _fold_node(self, node, graph):
6969
const_outputs = process_func(node, graph)
7070
self._replace_node_with_const(node, graph, const_outputs)
7171
return True
72-
self.logger.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
72+
self.logger.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
7373
return False
7474

7575
@staticmethod

tf2onnx/optimizer/optimizer_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ def _print_stat_diff(self, before, after):
6363
diff = copy.deepcopy(after)
6464
diff.subtract(before)
6565
diff = ["{} {} ({}->{})".format(k, str(v) if v < 0 else '+' + str(v), before.get(k, 0), after.get(k, 0))
66-
for k, v in diff.most_common() if v != 0]
66+
for k, v in sorted(diff.items()) if v != 0]
6767
self.logger.verbose(', '.join(diff) if diff else "no change")

tf2onnx/shape_inference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
def infer_shape(tf_graph, shape_override):
2525
"""Infer shape for TF graph with shape_override set first."""
2626
if shape_override:
27+
logger.info("Apply shape override:")
2728
for name, shape in shape_override.items():
29+
logger.info("\tSet %s shape to %s", name, shape)
2830
tf_graph.get_tensor_by_name(name).set_shape(shape)
2931
tf_graph = reload_tf_graph(tf_graph)
3032

0 commit comments

Comments
 (0)