Skip to content

Commit 11fb6a0

Browse files
authored
Merge pull request #508 from nbcsm/fix
Fix mobilenet_v2_1.4_224 failure on tf < 1.8
2 parents 348b4dd + f103eb5 commit 11fb6a0

File tree

5 files changed

+44
-12
lines changed

5 files changed

+44
-12
lines changed

tests/run_pretrained_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
245245
else:
246246
tf_results = self.run_tensorflow(sess, inputs)
247247
logger.info("TensorFlow OK")
248+
248249
model_proto = None
249250
try:
250251
# convert model to onnx
@@ -253,12 +254,11 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
253254
model_proto = onnx_graph.make_model("converted from tf2onnx")
254255
model_proto = optimizer.optimize_graph(onnx_graph).make_model("optimized")
255256
logger.info("To_ONNX, OK")
256-
if utils.is_debug_mode():
257-
onnx_graph.dump_graph()
258257
if onnx_file:
259258
self.create_onnx_file(name, model_proto, inputs, onnx_file)
260259
except Exception:
261260
logger.error("To_ONNX FAIL", exc_info=1)
261+
return False
262262

263263
try:
264264
onnx_results = None

tf2onnx/graph.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,28 @@ def __str__(self):
152152
def __repr__(self):
153153
return "<onnx op type='%s' name=%s>" % (self.type, self._op.name)
154154

155+
@property
156+
def summary(self):
157+
"""Return node summary information."""
158+
lines = []
159+
lines.append("OP={}".format(self.type))
160+
lines.append("Name={}".format(self.name))
161+
162+
g = self.graph
163+
if self.input:
164+
lines.append("Inputs:")
165+
for name in self.input:
166+
node = g.get_node_by_output(name)
167+
op = node.type if node else "N/A"
168+
lines.append("\t{}={}, {}, {}".format(name, op, g.get_shape(name), g.get_dtype(name)))
169+
170+
if self.output:
171+
for name in self.output:
172+
lines.append("Outpus:")
173+
lines.append("\t{}={}, {}".format(name, g.get_shape(name), g.get_dtype(name)))
174+
175+
return '\n'.join(lines)
176+
155177
def get_attr(self, name, default=None):
156178
"""Get raw attribute value."""
157179
attr = self.attr.get(name, default)
@@ -436,6 +458,8 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
436458
if op_name_scope:
437459
name = "_".join([op_name_scope, name])
438460

461+
logger.debug("Making node: Name=%s, OP=%s", name, op_type)
462+
439463
if outputs is None:
440464
outputs = [name + ":" + str(i) for i in range(output_count)]
441465

@@ -479,6 +503,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
479503
if (not shapes or not dtypes) and infer_shape_dtype:
480504
self.update_node_shape_dtype(node, override=False)
481505

506+
logger.debug("Made node: %s\n%s", node.name, node.summary)
482507
self._nodes.append(node)
483508
return node
484509

@@ -904,7 +929,11 @@ def dump_graph(self):
904929
"""Dump graph with shapes (helpful for debugging)."""
905930
for node in self.get_nodes():
906931
input_names = ["{}{}".format(n, self.get_shape(n)) for n in node.input]
907-
print("{} {} {} {}".format(node.type, self.get_shape(node.output[0]), node.name, ", ".join(input_names)))
932+
logger.debug("%s %s %s %s",
933+
node.type,
934+
self.get_shape(node.output[0]),
935+
node.name,
936+
", ".join(input_names))
908937

909938
def follow_inputs(self, node, num, space=""):
910939
"""Follow inputs for (helpful for debugging)."""

tf2onnx/onnx_opset/nn.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from tf2onnx.handler import tf_op
2020
from tf2onnx.onnx_opset import common, controlflow, tensor
2121

22-
2322
logger = logging.getLogger(__name__)
2423

24+
2525
# pylint: disable=unused-argument,missing-docstring,unused-variable
2626

2727
def spatial_map(shape, perm):
@@ -151,8 +151,10 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
151151
output_shape = spatial_map(output_shape, constants.NHWC_TO_NCHW)
152152
# calculate pads
153153
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
154-
logger.debug("node %s has unknown dim %d for pads calculation, fallback to auto_pad",
155-
node.name, input_shape)
154+
logger.debug(
155+
"node %s has unknown dim for pads calculation, fallback to auto_pad: "
156+
"input_shape=%s, output_shape=%s",
157+
node.name, input_shape, output_shape)
156158
node.set_attr("auto_pad", "SAME_UPPER")
157159
else:
158160
for i in range(spatial):
@@ -327,6 +329,7 @@ def _convert(cls, ctx, node, **kwargs):
327329
add_padding(ctx, node, kernel_shape, strides)
328330
conv_convert_inputs(ctx, node, with_kernel=False)
329331

332+
330333
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
331334
class MaxPoolWithArgmaxOp:
332335
@classmethod

tf2onnx/shape_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"Identity",
2727
"LogicalNot",
2828
"ReverseSequence",
29+
"Relu6",
2930
"Sigmoid",
3031
"Square",
3132
"Tanh"

tf2onnx/tfonnx.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,14 @@ def rewrite_conv2d_with_pad(g, ops):
533533

534534

535535
def tensorflow_onnx_mapping(g, continue_on_error, ops_mapping):
536+
logger.verbose("Mapping TF node to ONNX node(s)")
536537
mapped_op = collections.Counter()
537538
unmapped_op = collections.Counter()
538539

539540
ops = [n for n in g.get_nodes()]
540541
for node in ops:
542+
logger.debug("Process node: %s\n%s", node.name, node.summary)
543+
541544
if node.need_skip():
542545
logger.debug("explicitly skip node " + node.name)
543546
continue
@@ -574,12 +577,8 @@ def tensorflow_onnx_mapping(g, continue_on_error, ops_mapping):
574577
func(g, node, **kwargs)
575578
node.skip_conversion = True
576579
except Exception as ex:
577-
type_, value_, traceback_ = sys.exc_info()
578-
logger.error("node %s: exception %s" % (node.name, ex))
579-
ex_ext = traceback.format_exception(type_, value_, traceback_)
580-
if continue_on_error:
581-
logger.info(ex_ext)
582-
else:
580+
logger.error("Failed to convert node %s\n%s", node.name, node.summary, exc_info=1)
581+
if not continue_on_error:
583582
raise ex
584583

585584
return mapped_op, unmapped_op

0 commit comments

Comments
 (0)