|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +from typing import Dict, List, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +from mmdet.models.utils import aligned_bilinear |
| 6 | +from mmengine.config import ConfigDict |
| 7 | +from torch import Tensor |
| 8 | + |
| 9 | +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params |
| 10 | +from mmdeploy.core import FUNCTION_REWRITER |
| 11 | +from mmdeploy.mmcv.ops.nms import multiclass_nms |
| 12 | + |
| 13 | + |
| 14 | +@FUNCTION_REWRITER.register_rewriter( |
| 15 | + 'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat') |
| 16 | +def condinst_bbox_head__predict_by_feat( |
| 17 | + self, |
| 18 | + cls_scores: List[Tensor], |
| 19 | + bbox_preds: List[Tensor], |
| 20 | + score_factors: Optional[List[Tensor]] = None, |
| 21 | + param_preds: Optional[List[Tensor]] = None, |
| 22 | + batch_img_metas: Optional[List[dict]] = None, |
| 23 | + cfg: Optional[ConfigDict] = None, |
| 24 | + rescale: bool = False, |
| 25 | + with_nms: bool = True, |
| 26 | +): |
| 27 | + ctx = FUNCTION_REWRITER.get_context() |
| 28 | + deploy_cfg = ctx.cfg |
| 29 | + |
| 30 | + assert len(cls_scores) == len(bbox_preds) |
| 31 | + device = bbox_preds[0].device |
| 32 | + cfg = self.test_cfg if cfg is None else cfg |
| 33 | + batch_size = bbox_preds[0].shape[0] |
| 34 | + featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores] |
| 35 | + |
| 36 | + all_level_points_strides = self.prior_generator.grid_priors( |
| 37 | + featmap_sizes, device=device, with_stride=True) |
| 38 | + all_level_points = [i[:, :2] for i in all_level_points_strides] |
| 39 | + all_level_strides = [i[:, 2] for i in all_level_points_strides] |
| 40 | + |
| 41 | + flatten_cls_scores = [ |
| 42 | + cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, |
| 43 | + self.cls_out_channels) |
| 44 | + for cls_score in cls_scores |
| 45 | + ] |
| 46 | + flatten_bbox_preds = [ |
| 47 | + bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) |
| 48 | + for bbox_pred in bbox_preds |
| 49 | + ] |
| 50 | + flatten_score_factors = [ |
| 51 | + score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1) |
| 52 | + for score_factor in score_factors |
| 53 | + ] |
| 54 | + flatten_param_preds = [ |
| 55 | + param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params) |
| 56 | + for param_pred in param_preds |
| 57 | + ] |
| 58 | + flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() |
| 59 | + flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) |
| 60 | + flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid() |
| 61 | + flatten_param_preds = torch.cat(flatten_param_preds, dim=1) |
| 62 | + |
| 63 | + points = torch.cat(all_level_points) |
| 64 | + strides = torch.cat(all_level_strides) |
| 65 | + tl_x = points[..., 0] - flatten_bbox_preds[..., 0] |
| 66 | + tl_y = points[..., 1] - flatten_bbox_preds[..., 1] |
| 67 | + br_x = points[..., 0] + flatten_bbox_preds[..., 2] |
| 68 | + br_y = points[..., 1] + flatten_bbox_preds[..., 3] |
| 69 | + |
| 70 | + bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1) |
| 71 | + scores = flatten_cls_scores |
| 72 | + score_factors = flatten_score_factors |
| 73 | + param_preds = flatten_param_preds |
| 74 | + scores = scores * score_factors |
| 75 | + |
| 76 | + # get post processing config |
| 77 | + post_params = get_post_processing_params(deploy_cfg) |
| 78 | + max_output_boxes_per_class = post_params.max_output_boxes_per_class |
| 79 | + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) |
| 80 | + score_threshold = cfg.get('score_thr', post_params.score_threshold) |
| 81 | + pre_top_k = post_params.pre_top_k |
| 82 | + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) |
| 83 | + |
| 84 | + dets, labels, inds = multiclass_nms( |
| 85 | + bboxes, |
| 86 | + scores, |
| 87 | + max_output_boxes_per_class, |
| 88 | + iou_threshold, |
| 89 | + score_threshold, |
| 90 | + pre_top_k=pre_top_k, |
| 91 | + keep_top_k=keep_top_k, |
| 92 | + output_index=True, |
| 93 | + ) |
| 94 | + |
| 95 | + batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1) |
| 96 | + points = points.unsqueeze(0).repeat(batch_size, 1, 1) |
| 97 | + strides = strides.unsqueeze(0).repeat(batch_size, 1) |
| 98 | + param_preds = param_preds[batch_inds, inds, :] |
| 99 | + points = points[batch_inds, inds, :] |
| 100 | + strides = strides[batch_inds, inds] |
| 101 | + results = dict( |
| 102 | + dets=dets, |
| 103 | + labels=labels, |
| 104 | + param_preds=param_preds, |
| 105 | + points=points, |
| 106 | + strides=strides) |
| 107 | + return results |
| 108 | + |
| 109 | + |
| 110 | +@FUNCTION_REWRITER.register_rewriter( |
| 111 | + 'mmdet.models.dense_heads.CondInstMaskHead.forward') |
| 112 | +def condinst_mask_head__forward(self, x: tuple, |
| 113 | + positive_infos: Dict[str, torch.Tensor]): |
| 114 | + mask_feats = self.mask_feature_head(x) |
| 115 | + |
| 116 | + param_preds = positive_infos['param_preds'] |
| 117 | + points = positive_infos['points'] |
| 118 | + strides = positive_infos['strides'] |
| 119 | + |
| 120 | + batch_size = points.shape[0] |
| 121 | + num_insts = points.shape[1] |
| 122 | + hw = mask_feats.size()[-2:] |
| 123 | + mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1) |
| 124 | + |
| 125 | + points = points.reshape(-1, 1, 2).unsqueeze(0) |
| 126 | + locations = self.prior_generator.single_level_grid_priors( |
| 127 | + hw, level_idx=0, device=mask_feats.device) |
| 128 | + locations = locations.unsqueeze(0).repeat(batch_size, 1, |
| 129 | + 1).reshape(batch_size, 1, -1, 2) |
| 130 | + centers = points.reshape(batch_size, -1, 1, 2) |
| 131 | + rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float() |
| 132 | + rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest) |
| 133 | + rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1]) |
| 134 | + mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2) |
| 135 | + |
| 136 | + weights, biases = _parse_dynamic_params(self, param_preds) |
| 137 | + mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases) |
| 138 | + mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1]) |
| 139 | + mask_preds = aligned_bilinear( |
| 140 | + mask_preds, int(self.mask_feat_stride / self.mask_out_stride)) |
| 141 | + return (mask_preds, ) |
| 142 | + |
| 143 | + |
| 144 | +@FUNCTION_REWRITER.register_rewriter( |
| 145 | + 'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat') |
| 146 | +def condinst_mask_head__predict_by_feat(self, |
| 147 | + mask_preds: Tensor, |
| 148 | + results_list: Dict[str, torch.Tensor], |
| 149 | + batch_img_metas: List[dict], |
| 150 | + rescale: bool = True, |
| 151 | + **kwargs): |
| 152 | + cfg = self.test_cfg |
| 153 | + |
| 154 | + dets = results_list['dets'] |
| 155 | + labels = results_list['labels'] |
| 156 | + img_hw = batch_img_metas[0]['img_shape'][:2] |
| 157 | + |
| 158 | + mask_preds = mask_preds.sigmoid() |
| 159 | + mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride) |
| 160 | + mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]] |
| 161 | + masks = (mask_preds > cfg.mask_thr).float() |
| 162 | + |
| 163 | + return dets, labels, masks |
| 164 | + |
| 165 | + |
| 166 | +def _parse_dynamic_params(self, params: Tensor): |
| 167 | + """parse the dynamic params for dynamic conv.""" |
| 168 | + batch_size = params.shape[0] |
| 169 | + num_insts = params.shape[1] |
| 170 | + params = params.permute(1, 0, 2) |
| 171 | + params_splits = list( |
| 172 | + torch.split_with_sizes( |
| 173 | + params, self.weight_nums + self.bias_nums, dim=2)) |
| 174 | + |
| 175 | + weight_splits = params_splits[:self.num_layers] |
| 176 | + bias_splits = params_splits[self.num_layers:] |
| 177 | + |
| 178 | + for idx in range(self.num_layers): |
| 179 | + if idx < self.num_layers - 1: |
| 180 | + weight_splits[idx] = weight_splits[idx].reshape( |
| 181 | + batch_size, num_insts, self.in_channels, -1) |
| 182 | + else: |
| 183 | + weight_splits[idx] = weight_splits[idx].reshape( |
| 184 | + batch_size, num_insts, 1, -1) |
| 185 | + |
| 186 | + return weight_splits, bias_splits |
| 187 | + |
| 188 | + |
| 189 | +def _dynamic_conv_forward(features: Tensor, weights: List[Tensor], |
| 190 | + biases: List[Tensor]): |
| 191 | + """dynamic forward, each layer follow a relu.""" |
| 192 | + n_layers = len(weights) |
| 193 | + x = features.flatten(0, 1).flatten(2) |
| 194 | + for i, (w, b) in enumerate(zip(weights, biases)): |
| 195 | + # replace dynamic conv with bmm |
| 196 | + w = w.flatten(0, 1) |
| 197 | + b = b.flatten(0, 1).unsqueeze(2) |
| 198 | + x = torch.bmm(w, x) |
| 199 | + x = x + b |
| 200 | + if i < n_layers - 1: |
| 201 | + x = x.clamp_(min=0) |
| 202 | + return x |
0 commit comments