Skip to content

Commit d7b179e

Browse files
Tests for tflite postprocess (#1302)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 5e39d16 commit d7b179e

File tree

10 files changed

+257
-19
lines changed

10 files changed

+257
-19
lines changed

tests/backend_test_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,13 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):
210210
def run_tflite(self, tflite_path, feed_dict):
211211
try:
212212
interpreter = tf.lite.Interpreter(tflite_path)
213-
interpreter.allocate_tensors()
214213
input_details = interpreter.get_input_details()
215214
output_details = interpreter.get_output_details()
216215
input_name_to_index = {n['name'].split(':')[0]: n['index'] for n in input_details}
217216
feed_dict_without_port = {k.split(':')[0]: v for k, v in feed_dict.items()}
217+
for k, v in feed_dict_without_port.items():
218+
interpreter.resize_tensor_input(input_name_to_index[k], v.shape)
219+
interpreter.allocate_tensors()
218220
# The output names might be different in the tflite but the order is the same
219221
output_names = [n['name'] for n in output_details]
220222
for k, v in feed_dict_without_port.items():

tests/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"check_opset_max_version",
3232
"skip_tf2",
3333
"skip_tflite",
34+
"requires_tflite",
3435
"check_opset_after_tf_version",
3536
"check_target",
3637
"skip_caffe2_backend",
@@ -195,6 +196,13 @@ def test(self):
195196
return decorator
196197

197198

199+
def requires_tflite(message=""):
200+
""" Skip test if tflite tests are disabled """
201+
config = get_test_config()
202+
reason = _append_message("test requires tflite", message)
203+
return unittest.skipIf(config.skip_tflite_tests, reason)
204+
205+
198206
def requires_custom_ops(message=""):
199207
""" Skip until custom ops framework is on PyPI. """
200208
reason = _append_message("test needs custom ops framework", message)

tests/test_tflite_postprocess.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""Unit Tests for TFLite_Detection_PostProcess op"""
5+
6+
import os
7+
import struct
8+
import numpy as np
9+
import flatbuffers
10+
11+
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
12+
from backend_test_base import Tf2OnnxBackendTestBase
13+
14+
from tf2onnx import utils
15+
from tf2onnx.tfonnx import process_tf_graph
16+
from tf2onnx import optimizer
17+
18+
from tf2onnx.tflite import Model, OperatorCode, SubGraph, Operator, Tensor, Buffer
19+
from tf2onnx.tflite.BuiltinOperator import BuiltinOperator
20+
from tf2onnx.tflite.TensorType import TensorType
21+
from tf2onnx.tflite.CustomOptionsFormat import CustomOptionsFormat
22+
23+
# pylint: disable=missing-docstring
24+
25+
26+
class TFLiteDetectionPostProcessTests(Tf2OnnxBackendTestBase):
27+
28+
@requires_tflite("TFLite_Detection_PostProcess")
29+
@check_opset_min_version(11, "Pad")
30+
def test_postprocess_model1(self):
31+
self._test_postprocess(num_classes=5, num_boxes=100, detections_per_class=2, max_detections=20)
32+
33+
@requires_tflite("TFLite_Detection_PostProcess")
34+
@check_opset_min_version(11, "Pad")
35+
def test_postprocess_model2(self):
36+
self._test_postprocess(num_classes=5, num_boxes=100, detections_per_class=7, max_detections=20)
37+
38+
@requires_tflite("TFLite_Detection_PostProcess")
39+
@check_opset_min_version(11, "Pad")
40+
def test_postprocess_model3(self):
41+
self._test_postprocess(num_classes=5, num_boxes=3, detections_per_class=7, max_detections=20)
42+
43+
@requires_tflite("TFLite_Detection_PostProcess")
44+
@check_opset_min_version(11, "Pad")
45+
def test_postprocess_model4(self):
46+
self._test_postprocess(num_classes=5, num_boxes=99, detections_per_class=2, max_detections=20, extra_class=True)
47+
48+
def _test_postprocess(self, num_classes, num_boxes, detections_per_class, max_detections, extra_class=False):
49+
model = self.make_postprocess_model(num_classes=num_classes, detections_per_class=detections_per_class,
50+
max_detections=max_detections, x_scale=11.0, w_scale=6.0)
51+
52+
np.random.seed(42)
53+
box_encodings_val = np.random.random_sample([1, num_boxes, 4]).astype(np.float32)
54+
if extra_class:
55+
num_classes += 1
56+
class_predictions_val = np.random.random_sample([1, num_boxes, num_classes]).astype(np.float32)
57+
anchors_val = np.random.random_sample([num_boxes, 4]).astype(np.float32)
58+
59+
feed_dict = {
60+
"box_encodings": box_encodings_val,
61+
"class_predictions": class_predictions_val,
62+
"anchors": anchors_val
63+
}
64+
65+
self.run_tflite_test(model, feed_dict)
66+
67+
def make_postprocess_model(self, max_detections=10, detections_per_class=100, max_classes_per_detection=1,
68+
use_regular_nms=True, nms_score_threshold=0.3, nms_iou_threshold=0.6, num_classes=90,
69+
x_scale=10.0, y_scale=10.0, w_scale=5.0, h_scale=5.0):
70+
"""Returns the bytes of a tflite model containing a single TFLite_Detection_PostProcess op"""
71+
72+
builder = flatbuffers.Builder(1024)
73+
74+
# op_code
75+
custom_code = builder.CreateString("TFLite_Detection_PostProcess")
76+
OperatorCode.OperatorCodeStart(builder)
77+
OperatorCode.OperatorCodeAddDeprecatedBuiltinCode(builder, BuiltinOperator.CUSTOM)
78+
OperatorCode.OperatorCodeAddCustomCode(builder, custom_code)
79+
OperatorCode.OperatorCodeAddBuiltinCode(builder, BuiltinOperator.CUSTOM)
80+
op_code = OperatorCode.OperatorCodeEnd(builder)
81+
82+
# op_codes
83+
Model.ModelStartOperatorCodesVector(builder, 1)
84+
builder.PrependUOffsetTRelative(op_code)
85+
op_codes = builder.EndVector(1)
86+
87+
# Make tensors
88+
# [names, shape, type tensors]
89+
ts = []
90+
inputs_info = [('box_encodings', [-1, -1, 4]), ('class_predictions', [-1, -1, -1]), ('anchors', [-1, 4])]
91+
outputs_info = [
92+
('detection_boxes', [-1, -1, 4]),
93+
('detection_classes', [-1, -1]),
94+
('detection_scores', [-1, -1]),
95+
('num_detections', [-1])
96+
]
97+
for name_info, shape_info in inputs_info + outputs_info:
98+
99+
name = builder.CreateString(name_info)
100+
shape = builder.CreateNumpyVector(np.maximum(np.array(shape_info, np.int32), 1))
101+
shape_signature = builder.CreateNumpyVector(np.array(shape_info, np.int32))
102+
103+
Tensor.TensorStart(builder)
104+
Tensor.TensorAddShape(builder, shape)
105+
Tensor.TensorAddType(builder, TensorType.FLOAT32)
106+
Tensor.TensorAddName(builder, name)
107+
Tensor.TensorAddShapeSignature(builder, shape_signature)
108+
ts.append(Tensor.TensorEnd(builder))
109+
110+
SubGraph.SubGraphStartTensorsVector(builder, len(ts))
111+
for tensor in reversed(ts):
112+
builder.PrependUOffsetTRelative(tensor)
113+
tensors = builder.EndVector(len(ts))
114+
115+
# inputs
116+
SubGraph.SubGraphStartInputsVector(builder, 3)
117+
for inp in reversed([0, 1, 2]):
118+
builder.PrependInt32(inp)
119+
inputs = builder.EndVector(3)
120+
121+
# outputs
122+
SubGraph.SubGraphStartOutputsVector(builder, 4)
123+
for out in reversed([3, 4, 5, 6]):
124+
builder.PrependInt32(out)
125+
outputs = builder.EndVector(4)
126+
127+
flexbuffer = \
128+
b'y_scale\x00nms_score_threshold\x00max_detections\x00x_scale\x00w_scale\x00nms_iou_threshold' \
129+
b'\x00use_regular_nms\x00h_scale\x00max_classes_per_detection\x00num_classes\x00detections_per_class' \
130+
b'\x00\x0b\x16E>\x88j\x9e([v\x7f\xab\x0b\x00\x00\x00\x01\x00\x00\x00\x0b\x00\x00\x00*attr4**attr7*' \
131+
b'*attr10**attr9**attr1**attr2**attr3**attr11*\x00\x00\x00*attr8**attr5**attr6*\x06\x0e\x06\x06\x0e' \
132+
b'\x0e\x06j\x0e\x0e\x0e7&\x01'
133+
flexbuffer = flexbuffer.replace(b'*attr1*', struct.pack('<f', nms_iou_threshold))
134+
flexbuffer = flexbuffer.replace(b'*attr2*', struct.pack('<f', nms_score_threshold))
135+
flexbuffer = flexbuffer.replace(b'*attr3*', struct.pack('<i', num_classes))
136+
flexbuffer = flexbuffer.replace(b'*attr4*', struct.pack('<i', detections_per_class))
137+
flexbuffer = flexbuffer.replace(b'*attr5*', struct.pack('<f', x_scale))
138+
flexbuffer = flexbuffer.replace(b'*attr6*', struct.pack('<f', y_scale))
139+
flexbuffer = flexbuffer.replace(b'*attr7*', struct.pack('<f', h_scale))
140+
flexbuffer = flexbuffer.replace(b'*attr8*', struct.pack('<f', w_scale))
141+
flexbuffer = flexbuffer.replace(b'*attr9*', struct.pack('<i', max_detections))
142+
flexbuffer = flexbuffer.replace(b'*attr10*', struct.pack('<i', max_classes_per_detection))
143+
flexbuffer = flexbuffer.replace(b'*attr11*', struct.pack('<b', use_regular_nms))
144+
145+
custom_options = builder.CreateNumpyVector(np.array(bytearray(flexbuffer)))
146+
147+
# operator
148+
Operator.OperatorStart(builder)
149+
Operator.OperatorAddOpcodeIndex(builder, 0)
150+
Operator.OperatorAddInputs(builder, inputs)
151+
Operator.OperatorAddOutputs(builder, outputs)
152+
Operator.OperatorAddCustomOptions(builder, custom_options)
153+
Operator.OperatorAddCustomOptionsFormat(builder, CustomOptionsFormat.FLEXBUFFERS)
154+
operator = Operator.OperatorEnd(builder)
155+
156+
# operators
157+
SubGraph.SubGraphStartOperatorsVector(builder, 1)
158+
builder.PrependUOffsetTRelative(operator)
159+
operators = builder.EndVector(1)
160+
161+
# subgraph
162+
SubGraph.SubGraphStart(builder)
163+
SubGraph.SubGraphAddTensors(builder, tensors)
164+
SubGraph.SubGraphAddInputs(builder, inputs)
165+
SubGraph.SubGraphAddOutputs(builder, outputs)
166+
SubGraph.SubGraphAddOperators(builder, operators)
167+
subgraph = SubGraph.SubGraphEnd(builder)
168+
169+
# subgraphs
170+
Model.ModelStartSubgraphsVector(builder, 1)
171+
builder.PrependUOffsetTRelative(subgraph)
172+
subgraphs = builder.EndVector(1)
173+
174+
description = builder.CreateString("Model for tflite testing")
175+
176+
Buffer.BufferStartDataVector(builder, 0)
177+
data = builder.EndVector(0)
178+
179+
Buffer.BufferStart(builder)
180+
Buffer.BufferAddData(builder, data)
181+
buffer = Buffer.BufferEnd(builder)
182+
183+
Model.ModelStartBuffersVector(builder, 1)
184+
builder.PrependUOffsetTRelative(buffer)
185+
buffers = builder.EndVector(1)
186+
187+
# model
188+
Model.ModelStart(builder)
189+
Model.ModelAddVersion(builder, 3)
190+
Model.ModelAddOperatorCodes(builder, op_codes)
191+
Model.ModelAddSubgraphs(builder, subgraphs)
192+
Model.ModelAddDescription(builder, description)
193+
Model.ModelAddBuffers(builder, buffers)
194+
model = Model.ModelEnd(builder)
195+
196+
builder.Finish(model, b"TFL3")
197+
return builder.Output()
198+
199+
def run_tflite_test(self, tflite_model, feed_dict, rtol=1e-07, atol=1e-5):
200+
tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
201+
dir_name = os.path.dirname(tflite_path)
202+
if dir_name:
203+
os.makedirs(dir_name, exist_ok=True)
204+
with open(tflite_path, 'wb') as f:
205+
f.write(tflite_model)
206+
tf_lite_output_data, output_names = self.run_tflite(tflite_path, feed_dict)
207+
208+
g = process_tf_graph(None, opset=self.config.opset,
209+
input_names=list(feed_dict.keys()),
210+
output_names=output_names,
211+
target=self.config.target,
212+
tflite_path=tflite_path)
213+
g = optimizer.optimize_graph(g)
214+
onnx_from_tfl_output = self.run_backend(g, output_names, feed_dict, postfix="_from_tflite")
215+
self.assert_results_equal(tf_lite_output_data, onnx_from_tfl_output, rtol, atol)

tf2onnx/tflite_handlers/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33
"""tf2onnx.tflite_handlers module"""
44

55
from . import (
66
tfl_math,
77
tfl_nn,
88
tfl_controlflow,
99
tfl_direct,
10-
tfl_tensor
10+
tfl_tensor,
11+
tfl_postprocess,
1112
)

tf2onnx/tflite_handlers/tfl_controlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33

44
"""
55
tfl_controlflow

tf2onnx/tflite_handlers/tfl_direct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33

44
"""
55
tfl_direct

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33

44
"""
55
tfl_math

tf2onnx/tflite_handlers/tfl_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33

44
"""
55
tfl_nn

tf2onnx/tflite_handlers/tfl_postprocess.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import logging
99
import numpy as np
10+
from onnx.onnx_pb import TensorProto
1011

1112
from tf2onnx.handler import tfl_op
1213
from tf2onnx import utils
@@ -76,7 +77,8 @@ def version_11(cls, ctx, node, **kwargs):
7677

7778
nms_inputs = [adjusted_boxes, scores, max_boxes_per_class_const, iou_threshold_const, score_threshold_const]
7879
# shape: [-1, 3], elts of format [batch_index, class_index, box_index]
79-
selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0}).output[0]
80+
selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0},
81+
op_name_scope=node.name).output[0]
8082

8183
selected_boxes_idx = GraphBuilder(ctx).make_slice(
8284
{'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
@@ -89,15 +91,23 @@ def version_11(cls, ctx, node, **kwargs):
8991
box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes], attr={'axis': 1}).output[0]
9092

9193
box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0]
92-
box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': box_cnt_dtype}).output[0]
9394

9495
adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]})
9596
detection_boxes = ctx.make_node('Gather', [adjusted_boxes_sq, selected_boxes_idx_sq]).output[0]
9697
class_predictions_sq = GraphBuilder(ctx).make_squeeze({'data': class_predictions, 'axes': [0]})
9798
detection_scores = ctx.make_node('GatherND', [class_predictions_sq, box_and_class_idx]).output[0]
9899

99100
k_const = ctx.make_const(utils.make_name('const_k'), np.array([max_detections], np.int64)).output[0]
100-
min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0]
101+
if ctx.opset >= 12:
102+
min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0]
103+
else:
104+
# Lower opsets only support Min between floats
105+
box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': TensorProto.FLOAT}).output[0]
106+
k_const_float = ctx.make_node('Cast', [k_const], attr={'to': TensorProto.FLOAT}).output[0]
107+
min_k_float = ctx.make_node('Min', [k_const_float, box_cnt_float]).output[0]
108+
min_k = ctx.make_node('Cast', [min_k_float], attr={'to': TensorProto.INT64}).output[0]
109+
min_k_cast = ctx.make_node('Cast', [min_k], attr={'to': box_cnt_dtype}).output[0]
110+
101111
scores_top_k, scores_top_k_idx = ctx.make_node('TopK', [detection_scores, min_k], output_count=2).output
102112

103113
scores_top_k_idx_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k_idx, 'axes': [0]})
@@ -107,7 +117,7 @@ def version_11(cls, ctx, node, **kwargs):
107117
classes_sort_cast = ctx.make_node('Cast', [selected_classes_sort], attr={'to': classes_dtype}).output[0]
108118
detection_boxes_sorted = ctx.make_node('Gather', [detection_boxes, scores_top_k_idx_unsq]).output[0]
109119

110-
pad_amount = ctx.make_node('Sub', [k_const, box_cnt]).output[0]
120+
pad_amount = ctx.make_node('Sub', [k_const, min_k]).output[0]
111121

112122
quad_zero_const = ctx.make_const(utils.make_name('quad_zero_const'), np.array([0, 0, 0, 0], np.int64)).output[0]
113123
duo_zero_const = ctx.make_const(utils.make_name('duo_zero_const'), np.array([0, 0], np.int64)).output[0]
@@ -123,4 +133,6 @@ def version_11(cls, ctx, node, **kwargs):
123133
ctx.replace_all_inputs(node.output[0], detection_boxes_padded)
124134
ctx.replace_all_inputs(node.output[1], detection_classes_padded)
125135
ctx.replace_all_inputs(node.output[2], detection_scores_padded)
126-
ctx.replace_all_inputs(node.output[3], box_cnt_float)
136+
ctx.replace_all_inputs(node.output[3], min_k_cast)
137+
138+
ctx.remove_node(node.name)

tf2onnx/tflite_handlers/tfl_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright (c) Microsoft Corporation. All rights reserved.
2-
# Licensed under the MIT license.
1+
# SPDX-License-Identifier: Apache-2.0
2+
33

44
"""
55
tfl_tensor

0 commit comments

Comments
 (0)