
Commit 1cb41b4

Tom/tflite flexops (#1374)
* Fix parsing of tflite string tensors and other tflite fixes
* WIP
* Work around string decoding bug for flex ops
* Topsort tflite subgraphs
* Disable some tflite tests
* Pylint

Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 4c29b91 commit 1cb41b4

File tree: 10 files changed (+197 -67 lines)


tests/backend_test_base.py

Lines changed: 4 additions & 0 deletions
@@ -197,6 +197,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):
         sess_outputs = [sess.graph.get_tensor_by_name(n) for n in outputs]
         converter = tf_lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)
         #converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.target_spec.supported_ops = [
+            tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
+            tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow flex ops.
+        ]

         from tensorflow.lite.python.convert import ConverterError
         try:
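For context, a minimal standalone sketch of what the added converter settings do (the `saved_model_dir` path is hypothetical): ops without a TFLite builtin kernel fall back to TensorFlow "flex" kernels when `SELECT_TF_OPS` is listed, which is what lets these tests exercise flex ops at all.

```python
import tensorflow as tf

def to_tflite_with_flex(saved_model_dir):
    """Convert a SavedModel to TFLite, allowing TF (flex) kernels as a fallback."""
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # prefer TFLite builtin kernels
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TensorFlow flex kernels
    ]
    return converter.convert()
```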

tests/test_backend.py

Lines changed: 20 additions & 5 deletions
@@ -1531,6 +1531,7 @@ def func(data, segments):
         self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})

     @check_opset_min_version(11, "Pad")
+    @skip_tflite("unknown rank")
     def test_segment_mean_unknown_rank(self):
         segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
         data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
@@ -1820,7 +1821,7 @@ def func():
             return tf.identity(x_, name=_TFOUTPUT)
         # since results are random, compare the shapes only
         g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
-        results = self.run_backend(g, [_OUTPUT], {})
+        results = self.run_backend(g, g.outputs, {})
         numbers = set(results[0].flatten())
         self.assertEqual(sorted(numbers), list(range(2, 10)))

@@ -1833,7 +1834,7 @@ def func():
             return tf.identity(x_, name=_TFOUTPUT)
         # since results are random, compare the shapes only
         g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
-        results = self.run_backend(g, [_OUTPUT], {})
+        results = self.run_backend(g, g.outputs, {})
         self.assertTrue(2 <= results[0] < 10)

     def test_randomuniform_int_nonconst_max(self):
@@ -1845,7 +1846,11 @@ def func(m):
             x_ = tf.identity(x_, name="output2")
             return tf.identity(x_, name=_TFOUTPUT)
         g = self._run_test_case(func, [_OUTPUT], {_INPUT: m_val}, check_value=False, check_shape=True)
-        results = self.run_backend(g, [_OUTPUT], {_INPUT: m_val})
+        feed_dict = {_INPUT: m_val}
+        if "input" in g.input_names:
+            # TFLite inputs don't have port numbers
+            feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
+        results = self.run_backend(g, g.outputs, feed_dict)
         numbers = set(results[0].flatten())
         self.assertEqual(sorted(numbers), list(range(8)))

@@ -1859,7 +1864,11 @@ def func(n, m):
             x_ = tf.identity(x_, name="output2")
             return tf.identity(x_, name=_TFOUTPUT)
         g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val}, check_value=False, check_shape=True)
-        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val})
+        feed_dict = {_INPUT: n_val, _INPUT1: m_val}
+        if "input" in g.input_names:
+            # TFLite inputs don't have port numbers
+            feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
+        results = self.run_backend(g, g.outputs, feed_dict)
         numbers = set(results[0].flatten())
         self.assertEqual(sorted(numbers), list(range(2, 10)))

@@ -1875,7 +1884,11 @@ def func(n, m, s):
             return tf.identity(x_, name=_TFOUTPUT)
         g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val},
                                 check_value=False, check_shape=True)
-        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val})
+        feed_dict = {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val}
+        if "input" in g.input_names:
+            # TFLite inputs don't have port numbers
+            feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
+        results = self.run_backend(g, g.outputs, feed_dict)
         numbers = set(results[0].flatten())
         self.assertEqual(sorted(numbers), list(range(2, 10)))

@@ -4097,6 +4110,7 @@ def func(splits, rt_dense_values, indices):

     @check_tf_min_version("1.14", "ragged needs tf 1.14")
     @check_opset_min_version(11, "CumSum")
+    @skip_tflite("unknown rank")
     def test_ragged_tensor_to_tensor(self):
         splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
         splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
@@ -4758,6 +4772,7 @@ def func(input_val):
         self.config.opset = current_opset

     @check_tf_min_version("1.14")
+    @skip_tflite("FlexRFFT2D")
     def test_rfft_ops(self):

         def dft_slow(x, M):
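The repeated feed-dict adjustment above boils down to one rule: TensorFlow feeds are keyed by "name:port", while TFLite interpreter inputs use bare names. A minimal sketch of that rule (the `strip_ports` helper is hypothetical, for illustration only):

```python
import numpy as np

def strip_ports(feed_dict):
    """Drop the ':port' suffix from tensor names, as TFLite inputs have no ports."""
    return {name.split(":")[0]: value for name, value in feed_dict.items()}

feed = {"input:0": np.zeros((2, 3), dtype=np.float32)}
print(list(strip_ports(feed)))  # ['input']
```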

tests/test_tflite_postprocess.py

Lines changed: 2 additions & 0 deletions
@@ -167,7 +167,9 @@ def make_postprocess_model(self, max_detections=10, detections_per_class=100, ma
         operators = builder.EndVector(1)

         # subgraph
+        graph_name = builder.CreateString("TFLite graph")
         SubGraph.SubGraphStart(builder)
+        SubGraph.SubGraphAddName(builder, graph_name)
         SubGraph.SubGraphAddTensors(builder, tensors)
         SubGraph.SubGraphAddInputs(builder, inputs)
         SubGraph.SubGraphAddOutputs(builder, outputs)
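The ordering here matters because flatbuffers builders do not allow creating new offsets (such as strings) while a table is under construction, so the subgraph name string has to be built before `SubGraphStart`. A minimal sketch of that constraint with the plain `flatbuffers` package, independent of the generated TFLite schema code:

```python
import flatbuffers

builder = flatbuffers.Builder(64)
# Allowed: no table is being built yet.
graph_name = builder.CreateString("TFLite graph")
# Calling builder.CreateString(...) between a table's Start()/End() calls
# would raise an error, which is why the string is created up front.
```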

tests/test_tflite_utils.py

Lines changed: 3 additions & 3 deletions
@@ -59,9 +59,9 @@ def func(a, b, c):
         self.assertEqual(1, len(tflite_graphs))
         onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
             parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes)
-        self.assertEqual(2, op_cnt['MUL'])
-        self.assertEqual(1, op_cnt['ADD'])
-        self.assertEqual(1, op_cnt['FULLY_CONNECTED'])
+        self.assertEqual(2, op_cnt['TFL_MUL'])
+        self.assertEqual(1, op_cnt['TFL_ADD'])
+        self.assertEqual(1, op_cnt['TFL_FULLY_CONNECTED'])

         self.assertEqual(1, attr_cnt['WeightsFormat'])
         self.assertEqual(names, inputs)

tf2onnx/flexbuffers.py

Lines changed: 23 additions & 17 deletions
@@ -32,8 +32,12 @@ def read_float(buffer, offset, bit_size):
         raise FlexbufferParseException("Invalid bit size for flexbuffer float: %d" % bit_size)


-def read_string(buffer, offset, size):
-    return buffer[offset:offset+size].decode('utf-8')
+def read_string(buffer, offset, size, decode_strings):
+    data = buffer[offset:offset+size]
+    if decode_strings:
+        # Flexbuffer requires all strings to be valid UTF-8 but FlexOps don't always respect this.
+        data = data.decode('utf-8')
+    return data


 def read_indirect(buffer, offset, bit_size):
@@ -44,16 +48,16 @@ def read_bytes(buffer, offset, size):
     return buffer[offset:offset+size]


-def read_array(buffer, offset, length, bit_size, packed_type):
+def read_array(buffer, offset, length, bit_size, packed_type, decode_strings):
     byte_size = 1 << bit_size
     arr = []
     for i in range(length):
         item_offset = offset + (i * byte_size)
-        arr.append(read_buffer(buffer, item_offset, bit_size, packed_type))
+        arr.append(read_buffer(buffer, item_offset, bit_size, packed_type, decode_strings))
     return arr


-def read_buffer(buffer, offset, parent_bit_size, packed_type):
+def read_buffer(buffer, offset, parent_bit_size, packed_type, decode_strings):
     """Recursively decode flatbuffer object into python representation"""
     bit_size = packed_type & 3
     value_type = packed_type >> 2
@@ -64,20 +68,22 @@ def read_buffer(buffer, offset, parent_bit_size, packed_type):
     if value_type in [0x1, 0x2, 0x3]:
         read_fn = {0x1: read_int, 0x2: read_uint, 0x3: read_float}[value_type]
         return read_fn(buffer, offset, parent_bit_size)
-    if value_type in [0x4, 0x5]:
+    if value_type == 0x4:
         str_offset = read_indirect(buffer, offset, parent_bit_size)
         size = 0
         while read_int(buffer, str_offset + size, 0) != 0:
             size += 1
-        return read_string(buffer, str_offset, size)
+        return read_string(buffer, str_offset, size, decode_strings)
     if value_type == 0x5:
         str_offset = read_indirect(buffer, offset, parent_bit_size)
-        size_byte_size = 1 << bit_size
+        size_bit_size = bit_size
+        size_byte_size = 1 << size_bit_size
         size = read_uint(buffer, str_offset - size_byte_size, bit_size)
         while read_int(buffer, str_offset + size, 0) != 0:
             size_byte_size <<= 1
-            size = read_uint(buffer, str_offset - size_byte_size, bit_size)
-        return read_string(buffer, str_offset, size)
+            size_bit_size += 1
+            size = read_uint(buffer, str_offset - size_byte_size, size_bit_size)
+        return read_string(buffer, str_offset, size, decode_strings)
     if value_type in [0x6, 0x7, 0x8]:
         read_fn = {0x6: read_int, 0x7: read_uint, 0x8: read_float}[value_type]
         data_offset = read_indirect(buffer, offset, parent_bit_size)
@@ -93,10 +99,10 @@ def read_buffer(buffer, offset, parent_bit_size, packed_type):
         obj = {}
         for i in range(length):
             key_offset = keys_vector_offset + i * key_byte_size
-            key = read_buffer(buffer, key_offset, key_bit_size, (0x4 << 2) | key_bit_size)
+            key = read_buffer(buffer, key_offset, key_bit_size, (0x4 << 2) | key_bit_size, decode_strings)
             value_offset = values_offset + i * byte_size
             value_packed_type = read_uint(buffer, packed_types_offset + i, 0)
-            value = read_buffer(buffer, value_offset, bit_size, value_packed_type)
+            value = read_buffer(buffer, value_offset, bit_size, value_packed_type, decode_strings)
             obj[key] = value
         return obj
     if value_type == 0xa:
@@ -107,21 +113,21 @@ def read_buffer(buffer, offset, parent_bit_size, packed_type):
         for i in range(length):
             item_offset = items_offset + (i * byte_size)
             packed_type = read_uint(buffer, packed_types_offset + i, 0)
-            arr.append(read_buffer(buffer, item_offset, bit_size, packed_type))
+            arr.append(read_buffer(buffer, item_offset, bit_size, packed_type, decode_strings))
         return arr
     if value_type in [0xb, 0xc, 0xd, 0xe, 0xf, 0x24]:
         length_offset = read_indirect(buffer, offset, parent_bit_size) - byte_size
         length = read_uint(buffer, length_offset, bit_size)
         item_value_type = value_type - 0xb + 0x1
         packed_type = item_value_type << 2
         items_offset = read_indirect(buffer, offset, parent_bit_size)
-        return read_array(buffer, items_offset, length, bit_size, packed_type)
+        return read_array(buffer, items_offset, length, bit_size, packed_type, decode_strings)
     if 0x10 <= value_type <= 0x18:
         length = (value_type - 0x10) // 3 + 2
         value_type = ((value_type - 0x10) % 3) + 1
         packed_type = value_type << 2
         items_offset = read_indirect(buffer, offset, parent_bit_size)
-        return read_array(buffer, items_offset, length, bit_size, packed_type)
+        return read_array(buffer, items_offset, length, bit_size, packed_type, decode_strings)
     if value_type == 0x19:
         data_offset = read_indirect(buffer, offset, parent_bit_size)
         size_offset = data_offset - byte_size
@@ -132,9 +138,9 @@ def read_buffer(buffer, offset, parent_bit_size, packed_type):
     raise FlexbufferParseException("Invalid flexbuffer value type %r" % value_type)


-def read_flexbuffer(buffer):
+def read_flexbuffer(buffer, decode_strings=True):
     byte_size = read_uint(buffer, len(buffer) - 1, 0)
     bit_size = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4}[byte_size]
     packed_type = read_uint(buffer, len(buffer) - 2, 0)
     offset = len(buffer) - 2 - byte_size
-    return read_buffer(buffer, offset, bit_size, packed_type)
+    return read_buffer(buffer, offset, bit_size, packed_type, decode_strings)
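A minimal sketch of the updated entry point. The three-byte buffer below is a valid flexbuffer root holding the int 42 (value byte, packed type `(0x1 << 2) | 0`, root byte width 1), so it can be decoded directly; `decode_strings=False` only changes how string values are returned (raw bytes instead of UTF-8 `str`), which is what the flex-op custom-options parsing relies on.

```python
from tf2onnx.flexbuffers import read_flexbuffer

buf = bytes([42, 0x04, 0x01])  # value byte, packed type (int, width 1), root width
print(read_flexbuffer(buf))                         # 42
print(read_flexbuffer(buf, decode_strings=False))   # 42 as well; only strings differ
```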

tf2onnx/onnx_opset/math.py

Lines changed: 2 additions & 1 deletion
@@ -584,6 +584,7 @@ def version_10(cls, ctx, node, **kwargs):

         shapes = node.output_shapes
         dtypes = [onnx_pb.TensorProto.BOOL] * len(node.output_dtypes)
+        outputs = node.output

         ctx.remove_node(node.name)

@@ -593,7 +594,7 @@ def version_10(cls, ctx, node, **kwargs):
                                  shapes=shapes, dtypes=dtypes)
         or_node = ctx.make_node("Or", inputs=[inf_node.output[0], nan_node.output[0]], name=utils.make_name(node.name),
                                 shapes=shapes, dtypes=dtypes)
-        _ = ctx.make_node("Not", inputs=or_node.output, name=node.name,
+        _ = ctx.make_node("Not", inputs=or_node.output, name=node.name, outputs=outputs,
                           shapes=shapes, dtypes=dtypes)

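The `outputs=outputs` fix matters because this handler decomposes `IsFinite(x)` into `Not(Or(IsInf(x), IsNaN(x)))`, and the final `Not` must reuse the removed node's output names so downstream consumers stay connected. A standalone sketch of that decomposition built directly with `onnx.helper` (not the tf2onnx handler itself):

```python
import onnx
from onnx import TensorProto, helper

nodes = [
    helper.make_node("IsInf", ["x"], ["is_inf"]),
    helper.make_node("IsNaN", ["x"], ["is_nan"]),
    helper.make_node("Or", ["is_inf", "is_nan"], ["inf_or_nan"]),
    helper.make_node("Not", ["inf_or_nan"], ["finite"]),  # writes the graph's output name
]
graph = helper.make_graph(
    nodes, "isfinite_decomposition",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])],
    [helper.make_tensor_value_info("finite", TensorProto.BOOL, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
onnx.checker.check_model(model)  # the decomposed graph is a valid ONNX model
```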

tf2onnx/tf_utils.py

Lines changed: 58 additions & 23 deletions
@@ -318,26 +318,75 @@ def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
             n.attr['key_dtype'].type = key_dtype
             n.attr['value_dtype'].type = val_dtype

-def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=None, use_default=None):
-    """
-    Convert the tf-node list into an onnx graph with minimal rewrites so
-    we can use the onnx graph as intermediate graph.
-    """
+def read_tf_node_def_attrs(node_def, input_dtypes, input_shapes):
+    """Given a tf node def, returns a dict of attribute names to values"""
+    from tf2onnx.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel
+    del node_def.input[:]
+    node_def.name = "node"
+
+    # read_tf_node_attrs uses some tf methods that require the node to be loaded into a valid TF graph
+    g = tf.Graph()
+    with g.as_default():
+        for i, (dtype, shape) in enumerate(zip(input_dtypes, input_shapes)):
+            inp = "input" + str(i)
+            tf_placeholder(dtype, name=inp, shape=shape)
+            node_def.input.append(inp)
+        mini_graph_def = g.as_graph_def()
+        mini_graph_def.node.append(node_def)
+    g2 = tf.Graph()
+    with g2.as_default():
+        with tf_session() as sess:
+            tf.import_graph_def(mini_graph_def, name='')
+            node = sess.graph.get_operation_by_name("node")
+            return read_tf_node_attrs(node)
+
+
+def read_tf_node_attrs(node):
+    """Given a tf Node, returns a dict of attribute names to values"""
+    attr = {}
+    attr_cnt = collections.Counter()

     # ignore the following attributes
-    ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
+    ignored_attr = {"T", "unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
                     "TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
                     "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
                     "Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
                     "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
                     "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
                     "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
                     "Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
+                    "body", "cond", "then_branch", "else_branch", "f",
                     "Tcomplex", "Treal",  # For RFFT, Tcomplex is ignored because
                                           # onnx.helper.make_node fails,
                                           # TODO: it should be added back.
                     }

+    for a in node.node_def.attr:
+        attr_cnt[a] += 1
+        value = get_tf_node_attr(node, a)
+        if a in ignored_attr or isinstance(value, tensor_pb2.TensorProto):
+            pass
+        elif a == "shape":
+            shape = get_tf_shape_attr(node)
+            if shape is not None:
+                attr[a] = shape
+        elif a == "DstT":
+            attr["to"] = map_tf_dtype(value)
+        elif isinstance(value, tf.DType):
+            attr[a] = map_tf_dtype(value)
+        elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
+            attr[a] = [map_tf_dtype(v) for v in value]
+        else:
+            attr[a] = get_tf_node_attr(node, a)
+
+    return attr, attr_cnt
+
+
+def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=None, use_default=None):
+    """
+    Convert the tf-node list into an onnx graph with minimal rewrites so
+    we can use the onnx graph as intermediate graph.
+    """
+
     node_list = g.get_operations()
     functions = {}

@@ -360,41 +409,27 @@ def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=Non
             dtypes[out.name] = map_tf_dtype(out.dtype)
             output_shapes[out.name] = shape

-    # minimal conversion of attributes
     for node in ops:
-        attr = {}
+        attr, new_attr_cnt = read_tf_node_attrs(node)
+        attr_cnt += new_attr_cnt
         takeit = True
         op_cnt[node.type] += 1
         for a in node.node_def.attr:
             attr_cnt[a] += 1
             value = get_tf_node_attr(node, a)
-            if a in ignored_attr:
-                pass
-            elif a == "T":
+            if a == "T":
                 if value and not isinstance(value, list):
                     dtypes[node.name] = map_tf_dtype(value)
-            elif a == "shape":
-                shape = get_tf_shape_attr(node)
-                if shape is not None:
-                    attr[a] = shape
             elif a in {"body", "cond", "then_branch", "else_branch", "f"}:
                 input_shapes = [inp.get_shape() for inp in node.inputs]
                 nattr = get_tf_node_attr(node, a)
                 attr[a] = nattr.name
                 functions[nattr.name] = input_shapes
-            elif a == "DstT":
-                attr["to"] = map_tf_dtype(value)
             elif isinstance(value, tensor_pb2.TensorProto):
                 if const_node_values and node.name in const_node_values:
                     value.tensor_content = const_node_values[node.name]
                 onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name))
                 attr[a] = onnx_tensor
-            elif isinstance(value, tf.DType):
-                attr[a] = map_tf_dtype(value)
-            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
-                attr[a] = [map_tf_dtype(v) for v in value]
-            else:
-                attr[a] = get_tf_node_attr(node, a)

         node_type = node.type
         input_names = [i.name for i in node.inputs]
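A minimal sketch of how the new helper is meant to be called (the Cast `NodeDef` below is made up for illustration): the node def is grafted onto a throwaway graph with placeholder inputs so TensorFlow's attribute accessors can be used, and the result is the same `(attr, attr_cnt)` pair that `read_tf_node_attrs` returns.

```python
import tensorflow as tf
from tensorflow.core.framework import node_def_pb2
from tf2onnx.tf_utils import read_tf_node_def_attrs

node_def = node_def_pb2.NodeDef(name="cast", op="Cast")
node_def.attr["SrcT"].type = tf.float32.as_datatype_enum
node_def.attr["DstT"].type = tf.int32.as_datatype_enum

# One float32 input of shape [2, 3] is enough to satisfy Cast's signature.
attrs, attr_counts = read_tf_node_def_attrs(node_def, [tf.float32], [[2, 3]])
print(attrs)  # DstT is mapped to the ONNX 'to' attribute, e.g. {'to': 6}
```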

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def to_tf(cls, ctx, node, **kwargs):
         pass

 @tfl_op(["TFL_REDUCE_MAX"], tf_op="Max")
+@tfl_op(["TFL_REDUCE_MIN"], tf_op="Min")
 @tfl_op(["TFL_REDUCE_ANY"], tf_op="Any")
 @tfl_op(["TFL_REDUCE_PROD"], tf_op="Prod")
 class TflReduceOp:
