|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | + |
| 4 | +""" |
| 5 | +tfl_postprocess |
| 6 | +""" |
| 7 | + |
| 8 | +import logging |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +from tf2onnx.handler import tfl_op |
| 12 | +from tf2onnx import utils |
| 13 | +from tf2onnx.graph_builder import GraphBuilder |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name |
| 19 | + |
| 20 | + |
| 21 | +@tfl_op(["TFL_TFLite_Detection_PostProcess"]) |
| 22 | +class TflDetectionPostProcess: |
| 23 | + @classmethod |
| 24 | + def version_11(cls, ctx, node, **kwargs): |
| 25 | + # This ops is basically NMS with a little post-processing. |
| 26 | + # TFLite implementation: |
| 27 | + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/kernels/detection_postprocess.cc |
| 28 | + |
| 29 | + # box_encodings.shape = [batch_dim, box_num, 4] |
| 30 | + # class_predictions.shape = [batch_dim, box_num, num_classes(+1)] |
| 31 | + # anchors.shape = [box_num, 4] |
| 32 | + box_encodings, class_predictions, anchors = node.input |
| 33 | + |
| 34 | + classes_dtype = ctx.get_dtype(node.output[1]) |
| 35 | + box_cnt_dtype = ctx.get_dtype(node.output[3]) |
| 36 | + |
| 37 | + num_classes = node.get_attr_value('num_classes') |
| 38 | + max_detections = node.get_attr_value('max_detections') |
| 39 | + |
| 40 | + # Remove 'other' class if present. |
| 41 | + max_int64 = int(utils.get_max_value(np.int64)) |
| 42 | + class_predictions = GraphBuilder(ctx).make_slice( |
| 43 | + {'data': class_predictions, 'starts': [-num_classes], 'ends': [max_int64], 'axes': [2]}) |
| 44 | + |
| 45 | + scaling_vector = [node.get_attr_value(a) for a in ['y_scale', 'x_scale', 'h_scale', 'w_scale']] |
| 46 | + scale_const = ctx.make_const(utils.make_name('scale_const'), np.array(scaling_vector, np.float32)).output[0] |
| 47 | + |
| 48 | + scaled_boxes = ctx.make_node('Div', [box_encodings, scale_const]).output[0] |
| 49 | + anchors_yx = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [0], 'ends': [2], 'axes': [1]}) |
| 50 | + anchors_hw = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [2], 'ends': [4], 'axes': [1]}) |
| 51 | + boxes_yx = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [0], 'ends': [2], 'axes': [2]}) |
| 52 | + boxes_hw = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [2], 'ends': [4], 'axes': [2]}) |
| 53 | + |
| 54 | + scaled_boxes_yx = ctx.make_node('Mul', [boxes_yx, anchors_hw]).output[0] |
| 55 | + boxes_hw_exp = ctx.make_node('Exp', [boxes_hw]).output[0] |
| 56 | + scaled_boxes_hw = ctx.make_node('Mul', [boxes_hw_exp, anchors_hw]).output[0] |
| 57 | + const_half = ctx.make_const(utils.make_name('const_half'), np.array(0.5, np.float32)).output[0] |
| 58 | + boxes_half_hw = ctx.make_node('Mul', [scaled_boxes_hw, const_half]).output[0] |
| 59 | + boxes_center_yx = ctx.make_node('Add', [scaled_boxes_yx, anchors_yx]).output[0] |
| 60 | + |
| 61 | + boxes_lower_left = ctx.make_node('Sub', [boxes_center_yx, boxes_half_hw]).output[0] |
| 62 | + boxes_upper_right = ctx.make_node('Add', [boxes_center_yx, boxes_half_hw]).output[0] |
| 63 | + adjusted_boxes = ctx.make_node('Concat', [boxes_lower_left, boxes_upper_right], attr={'axis': 2}).output[0] |
| 64 | + |
| 65 | + iou_threshold = np.array(node.get_attr_value('nms_iou_threshold'), np.float32) |
| 66 | + iou_threshold_const = ctx.make_const(utils.make_name('iou_threshold'), iou_threshold).output[0] |
| 67 | + |
| 68 | + score_threshold = np.array(node.get_attr_value('nms_score_threshold'), np.float32) |
| 69 | + score_threshold_const = ctx.make_const(utils.make_name('score_threshold'), score_threshold).output[0] |
| 70 | + |
| 71 | + boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64) |
| 72 | + max_boxes_per_class_const = ctx.make_const(utils.make_name('max_boxes_per_class'), boxes_per_class).output[0] |
| 73 | + |
| 74 | + # scores.shape = [batch_dim, classes_num, box_num] |
| 75 | + scores = ctx.make_node('Transpose', [class_predictions], attr={'perm': [0, 2, 1]}).output[0] |
| 76 | + |
| 77 | + nms_inputs = [adjusted_boxes, scores, max_boxes_per_class_const, iou_threshold_const, score_threshold_const] |
| 78 | + # shape: [-1, 3], elts of format [batch_index, class_index, box_index] |
| 79 | + selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0}).output[0] |
| 80 | + |
| 81 | + selected_boxes_idx = GraphBuilder(ctx).make_slice( |
| 82 | + {'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]}) |
| 83 | + selected_boxes_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_idx, 'axes': [1]}) |
| 84 | + |
| 85 | + selected_classes = GraphBuilder(ctx).make_slice( |
| 86 | + {'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]}) |
| 87 | + selected_classes_sq = GraphBuilder(ctx).make_squeeze({'data': selected_classes, 'axes': [1]}) |
| 88 | + |
| 89 | + box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes], attr={'axis': 1}).output[0] |
| 90 | + |
| 91 | + box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0] |
| 92 | + |
| 93 | + box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': box_cnt_dtype}).output[0] |
| 94 | + |
| 95 | + adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]}) |
| 96 | + detection_boxes = ctx.make_node('Gather', [adjusted_boxes_sq, selected_boxes_idx_sq]).output[0] |
| 97 | + class_predictions_sq = GraphBuilder(ctx).make_squeeze({'data': class_predictions, 'axes': [0]}) |
| 98 | + detection_scores = ctx.make_node('GatherND', [class_predictions_sq, box_and_class_idx]).output[0] |
| 99 | + |
| 100 | + k_const = ctx.make_const(utils.make_name('const_k'), np.array([max_detections], np.int64)).output[0] |
| 101 | + min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0] |
| 102 | + scores_top_k, scores_top_k_idx = ctx.make_node('TopK', [detection_scores, min_k], output_count=2).output |
| 103 | + |
| 104 | + scores_top_k_idx_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k_idx, 'axes': [0]}) |
| 105 | + scores_top_k_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k, 'axes': [0]}) |
| 106 | + |
| 107 | + selected_classes_sort = ctx.make_node('Gather', [selected_classes_sq, scores_top_k_idx_unsq]).output[0] |
| 108 | + classes_sort_cast = ctx.make_node('Cast', [selected_classes_sort], attr={'to': classes_dtype}).output[0] |
| 109 | + detection_boxes_sorted = ctx.make_node('Gather', [detection_boxes, scores_top_k_idx_unsq]).output[0] |
| 110 | + |
| 111 | + pad_amount = ctx.make_node('Sub', [k_const, box_cnt]).output[0] |
| 112 | + |
| 113 | + quad_zero_const = ctx.make_const(utils.make_name('quad_zero_const'), np.array([0, 0, 0, 0], np.int64)).output[0] |
| 114 | + duo_zero_const = ctx.make_const(utils.make_name('duo_zero_const'), np.array([0, 0], np.int64)).output[0] |
| 115 | + zero_const = ctx.make_const(utils.make_name('zero_const'), np.array([0], np.int64)).output[0] |
| 116 | + |
| 117 | + pads_3d = ctx.make_node('Concat', [quad_zero_const, pad_amount, zero_const], attr={'axis': 0}).output[0] |
| 118 | + pads_2d = ctx.make_node('Concat', [duo_zero_const, zero_const, pad_amount], attr={'axis': 0}).output[0] |
| 119 | + |
| 120 | + detection_boxes_padded = ctx.make_node('Pad', [detection_boxes_sorted, pads_3d]).output[0] |
| 121 | + detection_classes_padded = ctx.make_node('Pad', [classes_sort_cast, pads_2d]).output[0] |
| 122 | + detection_scores_padded = ctx.make_node('Pad', [scores_top_k_unsq, pads_2d]).output[0] |
| 123 | + |
| 124 | + ctx.replace_all_inputs(node.output[0], detection_boxes_padded) |
| 125 | + ctx.replace_all_inputs(node.output[1], detection_classes_padded) |
| 126 | + ctx.replace_all_inputs(node.output[2], detection_scores_padded) |
| 127 | + ctx.replace_all_inputs(node.output[3], box_cnt_float) |
0 commit comments