
Commit bb7df35

Optimization for tflite loops (#1289)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 3dd7745 commit bb7df35

5 files changed: +211 -3 lines changed

tf2onnx/graph.py

Lines changed: 3 additions & 0 deletions
@@ -465,6 +465,9 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
         self.graph_name = graph_name or utils.make_name("tf2onnx")
         self._is_subgraph = is_subgraph
         self.ta_reads = []
+        # A list of index, output tuples of potential scan outputs in this graph
+        # Used by the tflite while loop handler
+        self.scan_outputs = []
         self.func_inputs = []

         self._target = set(target)
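The new Graph.scan_outputs attribute is only a hand-off point between the rewriter and the loop handler. A minimal sketch of the intended contract, using a stand-in class and hypothetical tensor names (nothing below is from the commit itself):

# Stand-in for the real tf2onnx Graph, hypothetical names only.
class BodyGraphStub:
    def __init__(self):
        # A list of index, output tuples of potential scan outputs in this graph
        self.scan_outputs = []

body = BodyGraphStub()
# tfl_scan_output_rewriter would record the loop-carried input index and the
# tensor produced for that slot on each iteration:
body.scan_outputs.append((2, "row_update:0"))  # hypothetical values
# The TFL_WHILE handler later reads the candidates, highest index first:
candidates = sorted(body.scan_outputs, reverse=True)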

tf2onnx/tflite_handlers/tfl_controlflow.py

Lines changed: 44 additions & 2 deletions
@@ -12,6 +12,7 @@
 from tf2onnx.handler import tfl_op
 from tf2onnx import utils
 from tf2onnx.tf_loader import find_function
+from tf2onnx.graph_builder import GraphBuilder
 from tf2onnx.onnx_opset.controlflow import parameter_binding, inline_subgraph


@@ -40,6 +41,19 @@ def version_7(cls, ctx, node, **kwargs):
         cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
         cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

+        # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
+        # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
+        scan_outputs = sorted(body.scan_outputs, reverse=True)
+        def input_is_unused(g, index):
+            return len(g.find_output_consumers(g.func_inputs[index])) == 0
+        scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]
+
+        for idx, _ in scan_outputs:
+            del tfl_while_inputs[idx]
+            output_shapes.append(output_shapes.pop(idx))
+            output_dtypes.append(output_dtypes.pop(idx))
+            output_names.append(output_names.pop(idx))
+
         max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

         loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
@@ -52,15 +66,21 @@ def version_7(cls, ctx, node, **kwargs):
         for k, v in output_map.items():
             ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

-        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph)
+        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)
+
+        for i in range(len(scan_outputs)):
+            squeeze_node = GraphBuilder(body).make_squeeze(
+                {'data': body.outputs[-1-i], "axes": [0]}, return_node=True)
+            body.outputs[-1-i] = squeeze_node.output[0]

         loop_node.set_body_graph_as_attr("body", body)

 def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
-                        output_dtypes, cond_graph):
+                        output_dtypes, cond_graph, scan_outputs):
     """Wire subgraph graph into main."""

     g = copy.deepcopy(g)
+    graph_inputs = g.func_inputs.copy()

     # onnx will pass in cond as argument
     iter_node = g.make_node("Placeholder", [], name=utils.make_name("iteration_num"),
@@ -69,6 +89,28 @@ def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
                             output_count=1, dtypes=[TensorProto.BOOL], shapes=[[]])
     cond_binding = parameter_binding(cond_graph, g.outputs)

+    to_remove = set()
+    for idx, scan_output in scan_outputs:
+        inp = g.get_node_by_output(graph_inputs[idx])
+
+        # Remove consumers of scan input
+        stack = [inp]
+        while stack:
+            node = stack.pop()
+            if node not in to_remove:
+                to_remove.add(node)
+                for out in node.output:
+                    stack += g.find_output_consumers(out)
+
+        # Remove scan input from cond graph
+        cond_binding = {k: "@@ALLOC" if v == g.outputs[idx] else v for k, v in cond_binding.items()}
+        del g.func_inputs[idx]
+        del g.outputs[idx]
+        g.outputs.append(scan_output)
+
+    for node in to_remove:
+        g.remove_node(node.name)
+
     # in onnx the body inputs are: index, cond, [loop_vars]
     g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs
     # tell graph lib to keep inputs in order
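For reference, an ONNX Loop body emits its loop-carried outputs first and its scan outputs last, and the Loop node stacks each scan output across iterations along a new leading axis; that is why the handler moves each candidate's shape, dtype and name to the end of the lists and squeezes the length-1 leading dim from the per-iteration tensor. A small sketch of the index bookkeeping with made-up loop variables (not from the commit), showing why candidates are processed from the highest index down:

# Made-up loop state for illustration only.
tfl_while_inputs = ["i:0", "limit:0", "rows:0", "acc:0"]
output_names = ["i_next:0", "limit:0", "rows_next:0", "acc_next:0"]

# (input index, per-iteration tensor) candidates found by the rewriter,
# sorted highest index first so earlier deletions don't shift later indices.
scan_outputs = sorted([(2, "row_update:0")], reverse=True)

for idx, _ in scan_outputs:
    del tfl_while_inputs[idx]                   # drop the loop-carried slot
    output_names.append(output_names.pop(idx))  # move its output to the end

print(tfl_while_inputs)  # ['i:0', 'limit:0', 'acc:0']
print(output_names)      # ['i_next:0', 'limit:0', 'acc_next:0', 'rows_next:0']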

tf2onnx/tflite_rewriters/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""tf2onnx.tflite_rewriters module"""
+
+from . import (
+    tfl_scan_output_rewriter
+)
tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+tf2onnx.tflite_rewriters.tfl_scan_output_rewriter - Identify a common slice/concat pattern in tflite subgraphs
+Effectively replace A = A[:i] + [B] + A[i+1:] with A[i] = B
+"""
+import numpy as np
+
+from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+
+
+# pylint: disable=missing-docstring
+
+def rewrite_slice_concat_to_scatter(g, ops):
+    pattern0 = \
+        OpTypePattern('TFL_CONCATENATION', name='concat', inputs=[
+            OpTypePattern('TFL_SLICE', name='begin_slice'),
+            OpTypePattern('*', name='middle'),
+            OpTypePattern('TFL_SLICE', name='end_slice')
+        ])
+
+    matcher = GraphMatcher(pattern0, allow_reorder=False)
+    match_results = list(matcher.match_ops(ops))
+    if match_results:
+        for match in match_results:
+            concat = match.get_op("concat")
+            begin_slice = match.get_op("begin_slice")
+            middle = match.get_op("middle")
+            end_slice = match.get_op("end_slice")
+            middle_shape = g.get_shape(middle.output[0])
+
+            # Both slices must be slicing the same tensor
+            if begin_slice.input[0] != end_slice.input[0]:
+                continue
+            original_tensor = begin_slice.input[0]
+            if concat.get_attr_int("axis") != 0:
+                continue
+            # The inserted slice must have length 1 (to be a single index)
+            if middle_shape is None or len(middle_shape) == 0 or middle_shape[0] != 1:
+                continue
+            rank = len(middle_shape)
+            scan_output = middle.output[0]
+            if not begin_slice.inputs[1].is_const() or not end_slice.inputs[2].is_const():
+                continue
+            # The first slice must start from the beginning (0) for all dims
+            if not all(v == 0 for v in begin_slice.inputs[1].get_tensor_value()):
+                continue
+            # The second slice must slice to the end (-1) for all dims
+            if not all(v == -1 for v in end_slice.inputs[2].get_tensor_value()):
+                continue
+            # The other slice dims are assembled by concatenation if rank > 1
+            if rank > 1:
+                begin_concat = begin_slice.inputs[2]
+                end_concat = end_slice.inputs[1]
+                if not begin_concat.type == "TFL_CONCATENATION":
+                    continue
+                if not end_concat.type == "TFL_CONCATENATION":
+                    continue
+                # Except for dim 0, slice from beginning to end
+                if not all(get_uniform_const_val(inp) == -1 for inp in begin_concat.inputs[1:]):
+                    continue
+                if not all(get_uniform_const_val(inp) == 0 for inp in end_concat.inputs[1:]):
+                    continue
+                begin_idx = begin_concat.inputs[0]
+                end_idx = end_concat.inputs[0]
+            else:
+                begin_idx = begin_slice.inputs[2]
+                end_idx = end_slice.inputs[1]
+            # For dim 0, slice to i for first part and from i+1 for second
+            if not node_is_one_plus_node(begin_idx, end_idx):
+                continue
+            out1, _ = get_out_and_offset(begin_idx)
+            graph_inps = [n.output[0] for n in g.inputs]
+            # To be a scan output, i must be a graph input
+            if out1 not in graph_inps:
+                continue
+            # The array being sliced must be a graph input
+            if original_tensor not in graph_inps:
+                continue
+            # The input/output index of i
+            idx = graph_inps.index(out1)
+            # The input/output index of the array
+            scan_output_idx = graph_inps.index(original_tensor)
+            # For a scan output, i must be assigned to i+1 with each iteration
+            if not node_is_one_plus_node(g.get_node_by_output(out1), g.get_node_by_output(g.outputs[idx])):
+                continue
+            if len(g.find_output_consumers(concat.output[0])) > 1:
+                continue
+
+            if g.opset < 10 and len(g.find_output_consumers(concat.output[0])) <= 1:
+                # If opset is < 10, conversion of the subgraph will fail unless we remove the slice nodes
+                # We add a tmp node to replace them.
+                shape = g.get_shape(concat.output[0])
+                dtype = g.get_dtype(concat.output[0])
+                tmp_node = g.make_node("TMP_SCAN_OUTPUT", [original_tensor, scan_output],
+                                       shapes=[shape], dtypes=[dtype])
+                g.replace_all_inputs(concat.output[0], tmp_node.output[0])
+
+            to_remove = []
+            out = g.outputs[scan_output_idx]
+            node = g.get_node_by_output(out)
+            to_remove.append(node)
+
+            while len(node.input) > 0 and node != concat:
+                out = node.input[0]
+                node = g.get_node_by_output(out)
+                to_remove.append(node)
+
+            to_remove += [begin_slice, end_slice, concat]
+
+            out = original_tensor
+            node = g.get_node_by_output(out)
+            to_remove.append(node)
+
+            while len(node.input) > 0:
+                out = node.input[0]
+                node = g.get_node_by_output(out)
+                to_remove.append(node)
+
+            if not g.is_safe_to_remove_nodes(to_remove):
+                continue
+
+            g.scan_outputs.append((scan_output_idx, scan_output))
+    return ops
+
+def get_uniform_const_val(n):
+    if not n.is_const():
+        return None
+    v = n.get_tensor_value(as_list=False).flatten()
+    if len(v) == 0:
+        return None
+    if np.all(v == v[0]):
+        return v[0]
+    return None
+
+def get_out_and_offset(n):
+    if n.type in ['TFL_RESHAPE', 'TFL_IDENTITY', 'Identity']:
+        return get_out_and_offset(n.inputs[0])
+    if n.type == 'TFL_ADD':
+        v1 = get_uniform_const_val(n.inputs[0])
+        v2 = get_uniform_const_val(n.inputs[1])
+        if v1 is not None and v2 is not None:
+            return '', v1 + v2
+        if v1 is not None:
+            inp2, o2 = get_out_and_offset(n.inputs[1])
+            return inp2, v1 + o2
+        if v2 is not None:
+            inp1, o1 = get_out_and_offset(n.inputs[0])
+            return inp1, v2 + o1
+    return n.output[0], 0
+
+def node_is_one_plus_node(node, one_plus_node):
+    n1, o1 = get_out_and_offset(node)
+    n2, o2 = get_out_and_offset(one_plus_node)
+    return n1 == n2 and o1 + 1 == o2
tf2onnx/tflite_utils.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
             output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist()
         buf = model.Buffers(tensor.Buffer())
         dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type())
-        if not buf.DataIsNone():
+        if not buf.DataIsNone() and tensor.Buffer() > 0:
             # For const values we use TF to decode the binary data from the buffer
             t = tensor_pb2.TensorProto()
             t.tensor_content = buf.DataAsNumpy().tobytes()
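Context for the extra guard, stated as an assumption rather than something spelled out in the commit: the TFLite flatbuffer schema reserves buffer index 0 as an empty sentinel, so a tensor whose Buffer() is 0 carries no constant data and should not be decoded as a constant even if that buffer happens to contain bytes. A sketch of the check, using only accessor calls that appear in the diff:

# Sketch only; relies on the assumption above about buffer index 0.
def tensor_has_const_data(model, tensor):
    buf = model.Buffers(tensor.Buffer())
    # Buffer 0 is the schema's empty sentinel, so it never holds const data.
    return tensor.Buffer() > 0 and not buf.DataIsNone()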
