Skip to content

Commit 05a2a92

Browse files
Added conversion of TFLite_Detection_PostProcess (#1293)
* Added conversion of TFLite_Detection_PostProcess Signed-off-by: Tom Wildenhain <[email protected]> * Fixed formatting Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8afdae5 commit 05a2a92

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tfl_postprocess
6+
"""
7+
8+
import logging
9+
import numpy as np
10+
11+
from tf2onnx.handler import tfl_op
12+
from tf2onnx import utils
13+
from tf2onnx.graph_builder import GraphBuilder
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
19+
20+
21+
@tfl_op(["TFL_TFLite_Detection_PostProcess"])
22+
class TflDetectionPostProcess:
23+
@classmethod
24+
def version_11(cls, ctx, node, **kwargs):
25+
# This ops is basically NMS with a little post-processing.
26+
# TFLite implementation:
27+
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/kernels/detection_postprocess.cc
28+
29+
# box_encodings.shape = [batch_dim, box_num, 4]
30+
# class_predictions.shape = [batch_dim, box_num, num_classes(+1)]
31+
# anchors.shape = [box_num, 4]
32+
box_encodings, class_predictions, anchors = node.input
33+
34+
classes_dtype = ctx.get_dtype(node.output[1])
35+
box_cnt_dtype = ctx.get_dtype(node.output[3])
36+
37+
num_classes = node.get_attr_value('num_classes')
38+
max_detections = node.get_attr_value('max_detections')
39+
40+
# Remove 'other' class if present.
41+
max_int64 = int(utils.get_max_value(np.int64))
42+
class_predictions = GraphBuilder(ctx).make_slice(
43+
{'data': class_predictions, 'starts': [-num_classes], 'ends': [max_int64], 'axes': [2]})
44+
45+
scaling_vector = [node.get_attr_value(a) for a in ['y_scale', 'x_scale', 'h_scale', 'w_scale']]
46+
scale_const = ctx.make_const(utils.make_name('scale_const'), np.array(scaling_vector, np.float32)).output[0]
47+
48+
scaled_boxes = ctx.make_node('Div', [box_encodings, scale_const]).output[0]
49+
anchors_yx = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [0], 'ends': [2], 'axes': [1]})
50+
anchors_hw = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [2], 'ends': [4], 'axes': [1]})
51+
boxes_yx = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [0], 'ends': [2], 'axes': [2]})
52+
boxes_hw = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [2], 'ends': [4], 'axes': [2]})
53+
54+
scaled_boxes_yx = ctx.make_node('Mul', [boxes_yx, anchors_hw]).output[0]
55+
boxes_hw_exp = ctx.make_node('Exp', [boxes_hw]).output[0]
56+
scaled_boxes_hw = ctx.make_node('Mul', [boxes_hw_exp, anchors_hw]).output[0]
57+
const_half = ctx.make_const(utils.make_name('const_half'), np.array(0.5, np.float32)).output[0]
58+
boxes_half_hw = ctx.make_node('Mul', [scaled_boxes_hw, const_half]).output[0]
59+
boxes_center_yx = ctx.make_node('Add', [scaled_boxes_yx, anchors_yx]).output[0]
60+
61+
boxes_lower_left = ctx.make_node('Sub', [boxes_center_yx, boxes_half_hw]).output[0]
62+
boxes_upper_right = ctx.make_node('Add', [boxes_center_yx, boxes_half_hw]).output[0]
63+
adjusted_boxes = ctx.make_node('Concat', [boxes_lower_left, boxes_upper_right], attr={'axis': 2}).output[0]
64+
65+
iou_threshold = np.array(node.get_attr_value('nms_iou_threshold'), np.float32)
66+
iou_threshold_const = ctx.make_const(utils.make_name('iou_threshold'), iou_threshold).output[0]
67+
68+
score_threshold = np.array(node.get_attr_value('nms_score_threshold'), np.float32)
69+
score_threshold_const = ctx.make_const(utils.make_name('score_threshold'), score_threshold).output[0]
70+
71+
boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64)
72+
max_boxes_per_class_const = ctx.make_const(utils.make_name('max_boxes_per_class'), boxes_per_class).output[0]
73+
74+
# scores.shape = [batch_dim, classes_num, box_num]
75+
scores = ctx.make_node('Transpose', [class_predictions], attr={'perm': [0, 2, 1]}).output[0]
76+
77+
nms_inputs = [adjusted_boxes, scores, max_boxes_per_class_const, iou_threshold_const, score_threshold_const]
78+
# shape: [-1, 3], elts of format [batch_index, class_index, box_index]
79+
selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0}).output[0]
80+
81+
selected_boxes_idx = GraphBuilder(ctx).make_slice(
82+
{'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
83+
selected_boxes_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_idx, 'axes': [1]})
84+
85+
selected_classes = GraphBuilder(ctx).make_slice(
86+
{'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]})
87+
selected_classes_sq = GraphBuilder(ctx).make_squeeze({'data': selected_classes, 'axes': [1]})
88+
89+
box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes], attr={'axis': 1}).output[0]
90+
91+
box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0]
92+
93+
box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': box_cnt_dtype}).output[0]
94+
95+
adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]})
96+
detection_boxes = ctx.make_node('Gather', [adjusted_boxes_sq, selected_boxes_idx_sq]).output[0]
97+
class_predictions_sq = GraphBuilder(ctx).make_squeeze({'data': class_predictions, 'axes': [0]})
98+
detection_scores = ctx.make_node('GatherND', [class_predictions_sq, box_and_class_idx]).output[0]
99+
100+
k_const = ctx.make_const(utils.make_name('const_k'), np.array([max_detections], np.int64)).output[0]
101+
min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0]
102+
scores_top_k, scores_top_k_idx = ctx.make_node('TopK', [detection_scores, min_k], output_count=2).output
103+
104+
scores_top_k_idx_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k_idx, 'axes': [0]})
105+
scores_top_k_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k, 'axes': [0]})
106+
107+
selected_classes_sort = ctx.make_node('Gather', [selected_classes_sq, scores_top_k_idx_unsq]).output[0]
108+
classes_sort_cast = ctx.make_node('Cast', [selected_classes_sort], attr={'to': classes_dtype}).output[0]
109+
detection_boxes_sorted = ctx.make_node('Gather', [detection_boxes, scores_top_k_idx_unsq]).output[0]
110+
111+
pad_amount = ctx.make_node('Sub', [k_const, box_cnt]).output[0]
112+
113+
quad_zero_const = ctx.make_const(utils.make_name('quad_zero_const'), np.array([0, 0, 0, 0], np.int64)).output[0]
114+
duo_zero_const = ctx.make_const(utils.make_name('duo_zero_const'), np.array([0, 0], np.int64)).output[0]
115+
zero_const = ctx.make_const(utils.make_name('zero_const'), np.array([0], np.int64)).output[0]
116+
117+
pads_3d = ctx.make_node('Concat', [quad_zero_const, pad_amount, zero_const], attr={'axis': 0}).output[0]
118+
pads_2d = ctx.make_node('Concat', [duo_zero_const, zero_const, pad_amount], attr={'axis': 0}).output[0]
119+
120+
detection_boxes_padded = ctx.make_node('Pad', [detection_boxes_sorted, pads_3d]).output[0]
121+
detection_classes_padded = ctx.make_node('Pad', [classes_sort_cast, pads_2d]).output[0]
122+
detection_scores_padded = ctx.make_node('Pad', [scores_top_k_unsq, pads_2d]).output[0]
123+
124+
ctx.replace_all_inputs(node.output[0], detection_boxes_padded)
125+
ctx.replace_all_inputs(node.output[1], detection_classes_padded)
126+
ctx.replace_all_inputs(node.output[2], detection_scores_padded)
127+
ctx.replace_all_inputs(node.output[3], box_cnt_float)

0 commit comments

Comments
 (0)