Skip to content

Commit 448d61d

Browse files
Refactor tflite/tf logic out of tfonnx into tf_utils and tflite_utils (#1599)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 552b0a2 commit 448d61d

File tree

4 files changed

+94
-70
lines changed

4 files changed

+94
-70
lines changed

tf2onnx/graph.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,15 +470,10 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
470470
self.ragged_variant_list_reads = []
471471
self.ragged_variant_list_writes = []
472472

473-
self._target = set(target)
474473
self._dtypes = dtypes
475-
476474
self._output_shapes = output_shapes
477-
self._opset = find_opset(opset)
478475

479-
if extra_opset is not None:
480-
utils.make_sure(isinstance(extra_opset, list), "invalid extra_opset")
481-
self._extra_opset = extra_opset
476+
self.set_config(target, opset, extra_opset)
482477

483478
self.outputs = output_names if output_names is not None else []
484479

@@ -537,6 +532,18 @@ def create_new_graph_with_same_config(self):
537532
return Graph([], output_shapes={}, dtypes={}, target=self._target, opset=self._opset,
538533
extra_opset=self.extra_opset, output_names=[])
539534

535+
def set_config(self, target=None, opset=None, extra_opset=None):
536+
"""Set graph fields containing conversion options"""
537+
if target is None:
538+
target = constants.DEFAULT_TARGET
539+
540+
self._opset = find_opset(opset)
541+
self._target = set(target)
542+
543+
if extra_opset is not None:
544+
utils.make_sure(isinstance(extra_opset, list), "invalid extra_opset")
545+
self._extra_opset = extra_opset
546+
540547
@property
541548
def input_names(self):
542549
"""Placeholder node outputs"""

tf2onnx/tflite_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tf2onnx.tflite.Model import Model
2020
from tf2onnx.flexbuffers import read_flexbuffer
2121
from tf2onnx.tf_utils import read_tf_node_def_attrs
22+
from tf2onnx.graph import Graph
2223
from tf2onnx import utils
2324

2425
logger = logging.getLogger(__name__)
@@ -129,6 +130,40 @@ def get_options_class(name):
129130
return getattr(module, name)
130131

131132

133+
def graphs_from_tflite(tflite_path, input_names=None, output_names=None):
134+
"""
135+
Given the path to a tflite model, returns a tuple (main_graph, subgraphs) of graph.py Graph objects
136+
inputs/outputs will be taken from main graph in model if not overridden
137+
"""
138+
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
139+
main_g = None
140+
subgraphs = []
141+
for i, tfl_graph in enumerate(tflite_graphs):
142+
is_main_g = i == len(tflite_graphs) - 1
143+
prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
144+
tensor_shapes_from_interpreter = None
145+
if is_main_g:
146+
tensor_shapes_from_interpreter = tensor_shapes
147+
onnx_nodes, _, _, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
148+
parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
149+
g_inputs = f_inputs
150+
g_outputs = f_outputs
151+
if is_main_g:
152+
# Override IO in main graph
153+
utils.check_io(input_names, output_names, output_shapes.keys())
154+
if input_names is not None:
155+
g_inputs = input_names
156+
if output_names is not None:
157+
g_outputs = output_names
158+
g = Graph(onnx_nodes, output_shapes, dtypes, input_names=g_inputs, output_names=g_outputs,
159+
is_subgraph=not is_main_g, graph_name=graph_name)
160+
if is_main_g:
161+
main_g = g
162+
else:
163+
subgraphs.append(g)
164+
return main_g, subgraphs
165+
166+
132167
def read_tflite_model(tflite_path):
133168
"""
134169
Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model)

tf2onnx/tfonnx.py

Lines changed: 29 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tf2onnx.shape_inference import infer_shape
2424
from tf2onnx.tf_loader import is_function, resolve_functions, set_function
2525
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
26-
from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph
26+
from tf2onnx.tflite_utils import graphs_from_tflite
2727

2828
from . import constants, logging, schemas, utils, handler
2929

@@ -417,61 +417,29 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
417417
"please upgrade onnx package to avoid potential conversion issue.",
418418
utils.get_onnx_version(), opset)
419419

420-
if shape_override is None:
421-
shape_override = {}
422420
if inputs_as_nchw is None:
423421
inputs_as_nchw = []
424-
if target is None:
425-
target = constants.DEFAULT_TARGET
426-
427-
def check_io(input_names, output_names, output_shapes):
428-
io_to_check = []
429-
if input_names:
430-
io_to_check.extend(input_names)
431-
if output_names:
432-
io_to_check.extend(output_names)
433-
if io_to_check:
434-
# check output existence in case user passed in wrong output ids
435-
non_exists = set(io_to_check) - set(output_shapes.keys())
436-
if non_exists:
437-
logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed"
438-
"in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
439-
non_exists)
440-
raise ValueError("Inputs/Outputs Not Found")
441422

423+
is_tflite = False
442424
if tflite_path is not None:
443-
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
444-
main_g = None
445-
subgraphs = []
446-
for i, tfl_graph in enumerate(tflite_graphs):
447-
is_main_g = i == len(tflite_graphs) - 1
448-
prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
449-
tensor_shapes_from_interpreter = None
450-
if is_main_g:
451-
tensor_shapes_from_interpreter = tensor_shapes
452-
onnx_nodes, _, _, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
453-
parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
454-
g_inputs = f_inputs
455-
g_outputs = f_outputs
456-
if is_main_g:
457-
# Override IO in main graph
458-
check_io(input_names, output_names, output_shapes)
459-
if input_names is not None:
460-
g_inputs = input_names
461-
if output_names is not None:
462-
g_outputs = output_names
463-
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs,
464-
not is_main_g, graph_name)
465-
if is_main_g:
466-
main_g = g
467-
else:
468-
subgraphs.append(g)
425+
main_g, subgraphs = graphs_from_tflite(tflite_path, input_names, output_names)
426+
is_tflite = True
427+
else:
428+
main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
429+
ignore_default, use_default)
469430

470-
g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
471-
target, {}, tensors_to_rename, is_tflite=True, dequantize=dequantize)
472-
return g
431+
for g in [main_g] + subgraphs:
432+
g.set_config(target, opset, extra_opset)
433+
g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
434+
initialized_tables, tensors_to_rename, is_tflite, dequantize)
435+
return g
473436

474-
# make tf2onnx internal subgraphs from the tensorflow subgraphs
437+
438+
def graphs_from_tf(tf_graph, input_names, output_names, shape_override=None, const_node_values=None,
439+
ignore_default=None, use_default=None):
440+
"""make tf2onnx internal subgraphs from the tensorflow subgraphs"""
441+
if shape_override is None:
442+
shape_override = {}
475443
ordered_func = resolve_functions(tf_graph)
476444
subgraphs = []
477445
for func in ordered_func:
@@ -483,7 +451,7 @@ def check_io(input_names, output_names, output_shapes):
483451
onnx_nodes, _, _, output_shapes, dtypes, _ = \
484452
tensorflow_to_onnx(func, shape_override, const_node_values, ignore_default, use_default)
485453

486-
fg = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, f_inputs_names, f_output_names,
454+
fg = Graph(onnx_nodes, output_shapes, dtypes, input_names=f_inputs_names, output_names=f_output_names,
487455
is_subgraph=True, graph_name=func.name)
488456
fold_constants_using_tf(fg, outputs_to_values)
489457
subgraphs.append(fg)
@@ -497,33 +465,30 @@ def check_io(input_names, output_names, output_shapes):
497465
onnx_nodes, _, _, output_shapes, dtypes, _ = \
498466
tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
499467

500-
check_io(input_names, output_names, output_shapes)
501-
main_g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names,
502-
is_subgraph)
468+
utils.check_io(input_names, output_names, output_shapes.keys())
469+
main_g = Graph(onnx_nodes, output_shapes, dtypes, input_names=input_names, output_names=output_names)
503470
fold_constants_using_tf(main_g, outputs_to_values)
504-
g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
505-
target, initialized_tables, tensors_to_rename)
506-
return g
471+
return main_g, subgraphs
507472

508473

509-
def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
474+
def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
510475
initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False):
511476

512477
if tensors_to_rename is not None:
513478
main_g.rename_tensors(tensors_to_rename)
514479
inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw]
515480

516481
for g in subgraphs:
517-
fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
482+
fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
518483
initialized_tables, is_tflite, dequantize)
519484
set_function(fg.graph_name, fg)
520-
g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
485+
g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
521486
initialized_tables, is_tflite,
522487
dequantize)
523488
return g
524489

525490

526-
def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
491+
def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
527492
initialized_tables, is_tflite=False, dequantize=False):
528493

529494
op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
@@ -628,11 +593,11 @@ def compat_handler(ctx, node, **kwargs):
628593

629594
# post-processing rewriters
630595
late_rewriters = []
631-
if constants.TARGET_RS5 in target:
596+
if g.is_target(constants.TARGET_RS5):
632597
late_rewriters.append(rewrite_incomplete_type_support_rs5)
633-
if constants.TARGET_RS6 in target:
598+
if g.is_target(constants.TARGET_RS6):
634599
late_rewriters.append(rewrite_incomplete_type_support_rs6)
635-
if constants.TARGET_CHANNELS_LAST in target:
600+
if g.is_target(constants.TARGET_CHANNELS_LAST):
636601
late_rewriters.append(rewrite_channels_last)
637602
if late_rewriters:
638603
run_rewriters(g, late_rewriters, continue_on_error)

tf2onnx/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,23 @@ def make_sure(bool_val, error_msg, *args):
260260
raise ValueError("make_sure failure: " + error_msg % args)
261261

262262

263+
def check_io(input_names, output_names, valid_outputs):
264+
"""Asserts that input_names and output_names are contained within valid_outputs else raises an error"""
265+
io_to_check = []
266+
if input_names:
267+
io_to_check.extend(input_names)
268+
if output_names:
269+
io_to_check.extend(output_names)
270+
if io_to_check:
271+
# check output existence in case user passed in wrong output ids
272+
non_exists = set(io_to_check) - set(valid_outputs)
273+
if non_exists:
274+
logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed"
275+
"in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
276+
non_exists)
277+
raise ValueError("Inputs/Outputs Not Found")
278+
279+
263280
def is_cpp_protobuf():
264281
return isinstance(ModelProto().ParseFromString, types.BuiltinFunctionType)
265282

0 commit comments

Comments
 (0)