
Commit 552b0a2

Refactor tfonnx.py (#1595)
* Refactor tfonnx
* Add Identity nodes even for subgraphs
* pylint
* Remove unnecessary subgraph check
* Fix dtype finding for multiple Identity nodes in loops

Signed-off-by: Tom Wildenhain <[email protected]>
1 parent feb2dcc commit 552b0a2

File tree: 3 files changed, +120 -105 lines

tf2onnx/graph.py

Lines changed: 54 additions & 30 deletions
@@ -503,35 +503,34 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
 
         self.reset_nodes(ops)
 
-        if not is_subgraph:
-            # add identity node after each output, in case it is renamed during conversion.
-            for o in self.outputs:
-                n = self.get_node_by_output_in_current_graph(o)
-                if n.is_graph_input():
-                    # Don't add identity if the node is also an input. We want to keep input names the same.
-                    continue
-                new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
-                n_shapes = n.output_shapes
-                n_dtypes = n.output_dtypes
-                body_graphs = n.graph.contained_graphs.pop(n.name, None)
-                self.remove_node(n.name)
-
-                new_outputs = [output if output != o else new_output_name for output in n.output]
-                # domain should be passed to new node
-                branches = {}
-                if body_graphs:
-                    for attr_name, body_graph in body_graphs.items():
-                        body_graph.parent_graph = self
-                        branches[attr_name] = body_graph
+        # add identity node after each output, in case it is renamed during conversion.
+        for o in self.outputs:
+            n = self.get_node_by_output_in_current_graph(o)
+            if n.is_graph_input():
+                # Don't add identity if the node is also an input. We want to keep input names the same.
+                continue
+            new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
+            n_shapes = n.output_shapes
+            n_dtypes = n.output_dtypes
+            body_graphs = n.graph.contained_graphs.pop(n.name, None)
+            self.remove_node(n.name)
+
+            new_outputs = [output if output != o else new_output_name for output in n.output]
+            # domain should be passed to new node
+            branches = {}
+            if body_graphs:
+                for attr_name, body_graph in body_graphs.items():
+                    body_graph.parent_graph = self
+                    branches[attr_name] = body_graph
 
-                _ = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
-                                   skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
-                                   domain=n.domain, branches=branches)
+            _ = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
+                               skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
+                               domain=n.domain, branches=branches)
 
-                self.replace_all_inputs(o, new_output_name, ops=self.get_nodes())
-                self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
-                self.copy_shape(new_output_name, o)
-                self.copy_dtype(new_output_name, o)
+            self.replace_all_inputs(o, new_output_name, ops=self.get_nodes())
+            self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
+            self.copy_shape(new_output_name, o)
+            self.copy_dtype(new_output_name, o)
 
     def create_new_graph_with_same_config(self):
         """Create a clean graph inheriting current graph's configuration."""
@@ -874,6 +873,23 @@ def is_const(self, output):
     def get_tensor_value(self, output, as_list=True):
         return self.get_node_by_output(output).get_tensor_value(as_list)
 
+    def rename_tensors(self, tensors_to_rename):
+        """Replace tensor names within nodes and graph inputs/outputs"""
+        def rename_list(l):
+            return [tensors_to_rename.get(t, t) for t in l]
+
+        def rename_keys(d):
+            return {tensors_to_rename.get(k, k): v for k, v in d.items()}
+
+        self._output_to_node_name = rename_keys(self._output_to_node_name)
+        self._output_to_consumers = rename_keys(self._output_to_consumers)
+        self._dtypes = rename_keys(self._dtypes)
+        self._output_shapes = rename_keys(self._output_shapes)
+        self.outputs = rename_list(self.outputs)
+        for node in self._nodes:
+            node._input = rename_list(node._input)
+            node._output = rename_list(node._output)
+
     def change_node_name(self, node, new_name):
         """Remove node in current graph."""
         utils.make_sure(new_name not in self._nodes_by_name, "node %s not unique ", new_name)
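A hedged usage sketch of the new `Graph.rename_tensors` helper: the mapping keys are existing tensor names and the values are the desired names, and one call rewrites node inputs/outputs plus the shape, dtype and consumer bookkeeping. The graph variable `g` and the tensor names below are placeholders, not taken from this commit:

```python
# Assumed: `g` is a converted tf2onnx Graph whose outputs include "StatefulPartitionedCall:0".
tensors_to_rename = {
    "StatefulPartitionedCall:0": "logits",  # give the output a friendlier name
    "serving_default_input:0": "image",     # and the input as well
}
g.rename_tensors(tensors_to_rename)

# Outputs, shapes and dtypes all follow the renamed keys, so later lookups keep working.
assert "logits" in g.outputs
print(g.get_shape("logits"), g.get_dtype("logits"))
```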
@@ -1232,15 +1248,23 @@ def follow_inputs(self, node, num, space=""):
             return []
         return val
 
-    def dump_node_statistics(self):
+    def dump_node_statistics(self, include_attrs=False, include_subgraphs=True):
+        """Return a counter of op types (and optionally attribute names) within the graph"""
         op_cnt = collections.Counter()
+        attr_cnt = collections.Counter()
         for n in self.get_nodes():
             op_cnt[n.type] += 1
+            for k in n.attr.keys():
+                attr_cnt[k] += 1
             body_graphs = n.get_body_graphs()
-            if body_graphs:
+            if body_graphs and include_subgraphs:
                 for b_g in body_graphs.values():
-                    op_cnt += b_g.dump_node_statistics()
+                    g_op_cnt, g_attr_cnt = b_g.dump_node_statistics(include_attrs=True, include_subgraphs=True)
+                    op_cnt += g_op_cnt
+                    attr_cnt += g_attr_cnt
 
+        if include_attrs:
+            return op_cnt, attr_cnt
         return op_cnt
 
     def remove_input(self, node, to_be_removed, input_index=None):
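For reference, the extended statistics helper can now be called as in this small sketch; the graph variable `g` is assumed to be an existing tf2onnx Graph:

```python
# Old behaviour is unchanged: op-type counts only, subgraphs included.
op_cnt = g.dump_node_statistics()

# New behaviour: op-type and attribute-name counters for the current graph only,
# which is how process_parsed_graph consumes it after this commit.
op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
print(op_cnt.most_common(5), attr_cnt.most_common(5))
```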

tf2onnx/onnx_opset/controlflow.py

Lines changed: 5 additions & 4 deletions
@@ -623,12 +623,13 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
 
     g.outputs = [cond_outputs[0]] + g.outputs[2:] + scan_outputs
 
-    # FIXME: onnx does not have a variant type so we try to fish for the dtype in a prior TensorListSetItem.
+    # onnx does not have a variant type so we try to fish for the dtype in a prior TensorListSetItem.
     for o in g.outputs:
         if g.get_dtype(o) == onnx_pb.TensorProto.UNDEFINED:
-            node = g.get_node_by_output(o)
-            if node.type in ["Identity"]:
-                g.set_dtype(o, node.inputs[0].output_dtypes[0])
+            curr_o = o
+            while g.get_node_by_output(curr_o).type == "Identity":
+                curr_o = g.get_node_by_output(curr_o).input[0]
+            g.copy_dtype(curr_o, o)
 
     for node in g.ragged_variant_list_reads:
         # Requires opset 11
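The new loop resolves an UNDEFINED output dtype by walking back through a whole chain of Identity nodes rather than looking only one node deep, which is what broke when multiple Identity nodes sat between a loop output and its real producer. A standalone sketch of the same walk (the helper name is illustrative, not part of the commit):

```python
def resolve_dtype_through_identities(g, output_name):
    """Follow producer Identity nodes until a tensor with a known dtype is reached,
    then copy that dtype onto the original output (mirrors the loop above)."""
    curr = output_name
    while g.get_node_by_output(curr).type == "Identity":
        curr = g.get_node_by_output(curr).input[0]   # step to the Identity's input tensor
    g.copy_dtype(curr, output_name)                  # assumes the chain ends at a typed producer
```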

tf2onnx/tfonnx.py

Lines changed: 61 additions & 71 deletions
@@ -33,7 +33,7 @@
 # pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
 # pylint: disable=unused-variable
 
-def fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes):
+def fold_constants_using_tf(g, outputs_to_values):
     ops = list(g.get_nodes())
     # pylint: disable=too-many-nested-blocks
     keep_looking = True
@@ -409,14 +409,13 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     del verbose
 
     opset = utils.find_opset(opset)
-    if not is_subgraph:
-        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
-                    get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
-        logger.info("Using opset <onnx, %s>", opset)
-        if opset > schemas.get_max_supported_opset_version():
-            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
-                           "please upgrade onnx package to avoid potential conversion issue.",
-                           utils.get_onnx_version(), opset)
+    logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
+                get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
+    logger.info("Using opset <onnx, %s>", opset)
+    if opset > schemas.get_max_supported_opset_version():
+        logger.warning("Currently installed onnx package %s is too low to support opset %s, "
+                       "please upgrade onnx package to avoid potential conversion issue.",
+                       utils.get_onnx_version(), opset)
 
     if shape_override is None:
         shape_override = {}
@@ -440,34 +439,17 @@ def check_io(input_names, output_names, output_shapes):
                          non_exists)
             raise ValueError("Inputs/Outputs Not Found")
 
-    def rename_tensors_in_dict(d):
-        if tensors_to_rename is None:
-            return d
-        return {tensors_to_rename.get(k, k): v for k, v in d.items()}
-
-    def rename_tensors_in_list(tensors):
-        if tensors_to_rename is None or tensors is None:
-            return tensors
-        return [tensors_to_rename.get(t, t) for t in tensors]
-
-    def rename_tensors_in_nodes(onnx_nodes):
-        if tensors_to_rename is None:
-            return
-        for n in onnx_nodes:
-            n.input[:] = rename_tensors_in_list(n.input)
-            n.output[:] = rename_tensors_in_list(n.output)
-
     if tflite_path is not None:
         tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
         main_g = None
-        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
+        subgraphs = []
         for i, tfl_graph in enumerate(tflite_graphs):
             is_main_g = i == len(tflite_graphs) - 1
             prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
             tensor_shapes_from_interpreter = None
             if is_main_g:
                 tensor_shapes_from_interpreter = tensor_shapes
-            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
+            onnx_nodes, _, _, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                 parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
             g_inputs = f_inputs
             g_outputs = f_outputs
@@ -478,63 +460,73 @@ def rename_tensors_in_nodes(onnx_nodes):
                 g_inputs = input_names
             if output_names is not None:
                 g_outputs = output_names
-            rename_tensors_in_nodes(onnx_nodes)
-            g_inputs = rename_tensors_in_list(g_inputs)
-            g_outputs = rename_tensors_in_list(g_outputs)
-            output_shapes = rename_tensors_in_dict(output_shapes)
-            dtypes = rename_tensors_in_dict(dtypes)
-            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs, is_subgraph)
-            fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
-                                      g_outputs, {}, {}, {}, op_cnt, attr_cnt, is_tflite=True, dequantize=dequantize)
-            fg.graph_name = graph_name
+            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs,
+                      not is_main_g, graph_name)
             if is_main_g:
-                main_g = fg
+                main_g = g
             else:
-                set_function(graph_name, fg)
+                subgraphs.append(g)
+
+        g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
+                           target, {}, tensors_to_rename, is_tflite=True, dequantize=dequantize)
+        return g
+
+    # make tf2onnx internal subgraphs from the tensorflow subgraphs
+    ordered_func = resolve_functions(tf_graph)
+    subgraphs = []
+    for func in ordered_func:
+        f_inputs_names = [t.name for t in func.inputs]
+        f_output_names = [t.name for t in func.outputs]
+
+        outputs_to_values, _ = compute_const_folding_using_tf(func, const_node_values, output_names)
+
+        onnx_nodes, _, _, output_shapes, dtypes, _ = \
+            tensorflow_to_onnx(func, shape_override, const_node_values, ignore_default, use_default)
 
-        return main_g
+        fg = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, f_inputs_names, f_output_names,
+                   is_subgraph=True, graph_name=func.name)
+        fold_constants_using_tf(fg, outputs_to_values)
+        subgraphs.append(fg)
 
     is_func = is_function(tf_graph)
     if not is_func:
         tf_graph = infer_shape(tf_graph, shape_override)
 
-    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
+    outputs_to_values, _ = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
 
-    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
+    onnx_nodes, _, _, output_shapes, dtypes, _ = \
         tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
-    if not is_subgraph:
-        # make tf2onnx internal subgraphs from the tensorflow subgraphs
-        ordered_func = resolve_functions(tf_graph)
-        for func in ordered_func:
-            f_inputs_names = [t.name for t in func.inputs]
-            f_output_names = [t.name for t in func.outputs]
-            fg = process_tf_graph(func, continue_on_error, False, target, opset,
-                                  custom_op_handlers, custom_rewriter,
-                                  extra_opset, shape_override, inputs_as_nchw,
-                                  f_inputs_names, f_output_names, is_subgraph=True,
-                                  const_node_values=const_node_values, tensors_to_rename=tensors_to_rename,
-                                  initialized_tables=initialized_tables)
-            fg.graph_name = func.name
-            set_function(func.name, fg)
 
     check_io(input_names, output_names, output_shapes)
+    main_g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names,
+                   is_subgraph)
+    fold_constants_using_tf(main_g, outputs_to_values)
+    g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
+                       target, initialized_tables, tensors_to_rename)
+    return g
+
+
+def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                   initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False):
 
-    if not is_subgraph:
-        rename_tensors_in_nodes(onnx_nodes)
-        input_names = rename_tensors_in_list(input_names)
-        output_names = rename_tensors_in_list(output_names)
-        output_shapes = rename_tensors_in_dict(output_shapes)
-        dtypes = rename_tensors_in_dict(dtypes)
-        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
-    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph)
-    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
-                             output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
+    if tensors_to_rename is not None:
+        main_g.rename_tensors(tensors_to_rename)
+        inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw]
+
+    for g in subgraphs:
+        fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                                  initialized_tables, is_tflite, dequantize)
+        set_function(fg.graph_name, fg)
+    g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                             initialized_tables, is_tflite,
+                             dequantize)
     return g
 
 
 def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
-                         output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt,
-                         is_tflite=False, dequantize=False):
+                         initialized_tables, is_tflite=False, dequantize=False):
+
+    op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
 
     if is_tflite:
         tfl_rewriters = []
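In outline, the refactor now builds plain `Graph` objects first (main graph plus one per TensorFlow function or TFLite subgraph) and then hands them to the new `process_graphs`, which applies `tensors_to_rename` to the main graph, converts every subgraph and registers it with `set_function`, and finally converts the main graph. A condensed, hedged sketch of that control flow with heavily trimmed signatures; the real functions take many more arguments (handlers, target, opset, tables, and so on) and the stand-ins below are illustrative only:

```python
# Condensed sketch of the new orchestration; not a drop-in replacement for tf2onnx code.
_functions = {}

def set_function(name, graph):
    """Stand-in for tf2onnx's function registry."""
    _functions[name] = graph

def process_parsed_graph(g):
    """Placeholder for the real per-graph pipeline (rewriters, op mapping, ...)."""
    return g

def process_graphs(main_g, subgraphs, tensors_to_rename=None):
    if tensors_to_rename is not None:
        main_g.rename_tensors(tensors_to_rename)            # rename once, up front
    for sub in subgraphs:                                    # subgraphs are converted first ...
        set_function(sub.graph_name, process_parsed_graph(sub))
    return process_parsed_graph(main_g)                      # ... and the main graph last
```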
@@ -587,8 +579,6 @@ def compat_handler(ctx, node, **kwargs):
     if inputs_as_nchw:
         transpose_inputs(g, inputs_as_nchw)
 
-    fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes)
-
     # pre-processing graph rewrites
     # bi-directional re-writer should be placed after single directional re-writer
     rewriters = [
@@ -626,7 +616,7 @@ def compat_handler(ctx, node, **kwargs):
     run_rewriters(g, rewriters, continue_on_error)
 
     # some nodes may already copied into inner Graph, so remove them from main Graph.
-    g.delete_unused_nodes(output_names)
+    g.delete_unused_nodes(g.outputs)
     topological_sort(g, continue_on_error)
 
     mapped_op, unmapped_op, exceptions = \
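The last change drops the separately threaded `output_names` in favor of `g.outputs`, which the Graph constructor now keeps stable by inserting an Identity after every output (the first graph.py hunk, applied to subgraphs as well after this commit). A minimal sketch of that pattern in plain ONNX, using illustrative tensor and node names that are not from the commit:

```python
from onnx import TensorProto, helper

# The real producer writes to an internal tensor name ...
mul = helper.make_node("Mul", ["x", "w"], ["raw_output"], name="mul_node")
# ... and an Identity carries the public graph output name.
ident = helper.make_node("Identity", ["raw_output"], ["prob:0"], name="mul_node_graph_outputs")

graph = helper.make_graph(
    [mul, ident],
    "identity_after_output_sketch",
    inputs=[helper.make_tensor_value_info(n, TensorProto.FLOAT, [1]) for n in ("x", "w")],
    outputs=[helper.make_tensor_value_info("prob:0", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph)
# Rewrites may freely rename or replace the producer of "raw_output"; the public
# output name "prob:0" survives because only the Identity node references it.
```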
