Skip to content

Commit 77e57de

Browse files
go through all op mappings regardless of unsupported ones
1 parent ef0af82 commit 77e57de

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

tf2onnx/tfonnx.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,11 @@ def rewrite_conv2d_with_pad(g, ops):
529529
return ops
530530

531531

532-
def tensorflow_onnx_mapping(g, continue_on_error, ops_mapping):
532+
def tensorflow_onnx_mapping(g, ops_mapping):
533533
logger.verbose("Mapping TF node to ONNX node(s)")
534534
mapped_op = collections.Counter()
535535
unmapped_op = collections.Counter()
536+
exceptions = []
536537

537538
ops = [n for n in g.get_nodes()]
538539
for node in ops:
@@ -545,40 +546,39 @@ def tensorflow_onnx_mapping(g, continue_on_error, ops_mapping):
545546
op = node.type
546547
map_info = ops_mapping.get(op)
547548
if map_info is None:
548-
if continue_on_error:
549-
unmapped_op[op] += 1
550-
continue
551-
else:
552-
raise ValueError("tensorflow op " + op + " is not supported")
549+
unmapped_op[op] += 1
550+
logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
551+
continue
553552
mapped_op[op] += 1
553+
554554
func, kwargs = map_info
555555
if kwargs:
556556
# if there is a onnx_op key we'll map the old type to a new type
557557
onnx_op = kwargs.get("onnx_op")
558558
if onnx_op:
559559
node.type = onnx_op
560-
try:
561-
body_graphs = node.get_body_graphs()
562-
if body_graphs:
563-
for attr, b_g in body_graphs.items():
564-
logger.debug("start handling subgraph of %s's attribute %s", node.name, attr)
565-
b_g.topological_sort(b_g.get_nodes())
566-
# we assume only ONNX nodes have subgraph defined in pre-rewriters.
567-
# that means, if we create node having subgraphs in this step, the
568-
# created subgraphs' nodes won't be mapped.
569-
m_ops, unm_ops = tensorflow_onnx_mapping(b_g, continue_on_error, ops_mapping)
570-
mapped_op += m_ops
571-
unmapped_op += unm_ops
572-
logger.debug("finish handling subgraph of %s's attribute %s", node.name, attr)
560+
body_graphs = node.get_body_graphs()
561+
if body_graphs:
562+
for attr, b_g in body_graphs.items():
563+
logger.debug("start handling subgraph of %s's attribute %s", node.name, attr)
564+
b_g.topological_sort(b_g.get_nodes())
565+
# we assume only ONNX nodes have subgraph defined in pre-rewriters.
566+
# that means, if we create node having subgraphs in this step, the
567+
# created subgraphs' nodes won't be mapped.
568+
m_ops, unm_ops, body_exceptions = tensorflow_onnx_mapping(b_g, ops_mapping)
569+
mapped_op += m_ops
570+
unmapped_op += unm_ops
571+
exceptions.extend(body_exceptions)
572+
logger.debug("finish handling subgraph of %s's attribute %s", node.name, attr)
573573

574+
try:
574575
func(g, node, **kwargs)
575576
node.skip_conversion = True
576577
except Exception as ex:
577578
logger.error("Failed to convert node %s\n%s", node.name, node.summary, exc_info=1)
578-
if not continue_on_error:
579-
raise ex
579+
exceptions.append(ex)
580580

581-
return mapped_op, unmapped_op
581+
return mapped_op, unmapped_op, exceptions
582582

583583

584584
def transpose_inputs(ctx, inputs_as_nchw):
@@ -783,7 +783,11 @@ def compat_handler(ctx, node, **kwargs):
783783
g.delete_unused_nodes(output_names)
784784
topological_sort(g, continue_on_error)
785785

786-
mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error, ops_mapping)
786+
mapped_op, unmapped_op, exceptions = tensorflow_onnx_mapping(g, ops_mapping)
787+
if unmapped_op:
788+
logger.error("Unsupported ops: %s", unmapped_op)
789+
if exceptions and not continue_on_error:
790+
raise exceptions[0]
787791

788792
# post-processing rewriters
789793
late_rewriters = []

0 commit comments

Comments
 (0)