|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +from typing import List, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import Tensor |
| 6 | + |
| 7 | +from mmdeploy.core import FUNCTION_REWRITER |
| 8 | + |
| 9 | + |
| 10 | +@FUNCTION_REWRITER.register_rewriter( |
| 11 | + 'mmdet.models.roi_heads.htc_roi_head.HybridTaskCascadeRoIHead.predict_mask' |
| 12 | +) |
| 13 | +def htc_roi_head__predict_mask(self, |
| 14 | + x: Tuple[Tensor], |
| 15 | + semantic_heat: Tensor, |
| 16 | + batch_img_metas: List[dict], |
| 17 | + results_list: List[Tensor], |
| 18 | + rescale: bool = False) -> List[Tensor]: |
| 19 | + dets, det_labels = results_list |
| 20 | + |
| 21 | + batch_size = dets.size(0) |
| 22 | + det_bboxes = dets[..., :4] |
| 23 | + batch_index = torch.arange( |
| 24 | + det_bboxes.size(0), |
| 25 | + device=det_bboxes.device).float().view(-1, 1, 1).expand( |
| 26 | + det_bboxes.size(0), det_bboxes.size(1), 1) |
| 27 | + mask_rois = torch.cat([batch_index, det_bboxes], dim=-1) |
| 28 | + mask_rois = mask_rois.view(-1, 5) |
| 29 | + |
| 30 | + mask_results = self._mask_forward( |
| 31 | + stage=-1, |
| 32 | + x=x, |
| 33 | + rois=mask_rois, |
| 34 | + semantic_feat=semantic_heat, |
| 35 | + training=False) |
| 36 | + |
| 37 | + mask_preds = mask_results['mask_preds'][0] |
| 38 | + num_det = det_bboxes.shape[1] |
| 39 | + segm_results = self.mask_head[-1].predict_by_feat( |
| 40 | + mask_preds, |
| 41 | + results_list, |
| 42 | + batch_img_metas, |
| 43 | + self.test_cfg, |
| 44 | + rescale=rescale) |
| 45 | + segm_results = segm_results.reshape(batch_size, num_det, |
| 46 | + segm_results.shape[-2], |
| 47 | + segm_results.shape[-1]) |
| 48 | + return dets, det_labels, segm_results |
0 commit comments