11# Copyright (c) OpenMMLab. All rights reserved.
22import numpy as np
33import torch
4- from mmcv .ops import nms_rotated
4+ from mmcv .ops import nms , nms_rotated
55
66
7- def merge_results (results , offsets , iou_thr = 0.1 , device = 'cpu' ):
7+ def translate_bboxes (bboxes , offset ):
8+ """Translate bboxes according to its shape.
9+
10+ If the bbox shape is (n, 5), the bboxes are regarded as horizontal bboxes
11+ and in (x, y, x, y, score) format. If the bbox shape is (n, 6), the bboxes
12+ are regarded as rotated bboxes and in (x, y, w, h, theta, score) format.
13+
14+ Args:
15+ bboxes (np.ndarray): The bboxes need to be translated. Its shape can
16+ only be (n, 5) and (n, 6).
17+ offset (np.ndarray): The offset to translate with shape being (2, ).
18+
19+ Returns:
20+ np.ndarray: Translated bboxes.
21+ """
22+ if bboxes .shape [1 ] == 5 :
23+ bboxes [:, :4 ] = bboxes [:, :4 ] + np .tile (offset , 2 )
24+ elif bboxes .shape [1 ] == 6 :
25+ bboxes [:, :2 ] = bboxes [:, :2 ] + offset
26+ else :
27+ raise TypeError ('Require the shape of `bboxes` to be (n, 5) or (n, 6),'
28+ f' but get `bboxes` with shape being { bboxes .shape } .' )
29+ return bboxes
30+
31+
32+ def map_masks (masks , offset , new_shape ):
33+ """Map masks to the huge image.
34+
35+ Args:
36+ masks (list[np.ndarray]): masks need to be mapped.
37+ offset (np.ndarray): The offset to translate with shape being (2, ).
38+ new_shape (tuple): A tuple of the huge image's width and height.
39+
40+ Returns:
41+ list[np.ndarray]: Mapped masks.
42+ """
43+ if not masks :
44+ return masks
45+
46+ new_width , new_height = new_shape
47+ x_start , y_start = offset
48+ mapped = []
49+ for mask in masks :
50+ ori_height , ori_width = mask .shape [:2 ]
51+
52+ x_end = x_start + ori_width
53+ if x_end > new_width :
54+ ori_width -= x_end - new_width
55+ x_end = new_width
56+
57+ y_end = y_start + ori_height
58+ if y_end > new_height :
59+ ori_height -= y_end - new_height
60+ y_end = new_height
61+
62+ extended_mask = np .zeros ((new_height , new_width ), dtype = np .bool )
63+ extended_mask [y_start :y_end ,
64+ x_start :x_end ] = mask [:ori_height , :ori_width ]
65+ mapped .append (extended_mask )
66+ return mapped
67+
68+
69+ def merge_results (results , offsets , img_shape , iou_thr = 0.1 , device = 'cpu' ):
870 """Merge patch results via nms.
971
1072 Args:
11- results (list[np.ndarray]): A list of patches results.
73+ results (list[np.ndarray] | list[tuple] ): A list of patches results.
1274 offsets (np.ndarray): Positions of the left top points of patches.
75+ img_shape (tuple): A tuple of the huge image's width and height.
1376 iou_thr (float): The IoU threshold of NMS.
1477 device (str): The device to call nms.
1578
@@ -18,20 +81,47 @@ def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
1881 """
1982 assert len (results ) == offsets .shape [0 ], 'The `results` should has the ' \
2083 'same length with `offsets`.'
21- merged_results = []
22- for results_pre_cls in zip ( * results ):
23- tran_dets = []
24- for dets , offset in zip ( results_pre_cls , offsets ):
25- dets [:, : 2 ] += offset
26- tran_dets . append ( dets )
27- tran_dets = np . concatenate ( tran_dets , axis = 0 )
28-
29- if tran_dets . size == 0 :
30- merged_results . append ( tran_dets )
84+ with_mask = isinstance ( results [ 0 ], tuple )
85+ num_patches = len ( results )
86+ num_classes = len ( results [ 0 ][ 0 ]) if with_mask else len ( results [ 0 ])
87+
88+ merged_bboxes = []
89+ merged_masks = []
90+ for cls in range ( num_classes ):
91+ if with_mask :
92+ dets_per_cls = [ results [ i ][ 0 ][ cls ] for i in range ( num_patches )]
93+ masks_per_cls = [ results [ i ][ 1 ][ cls ] for i in range ( num_patches )]
3194 else :
32- tran_dets = torch .from_numpy (tran_dets )
33- tran_dets = tran_dets .to (device )
34- nms_dets , _ = nms_rotated (tran_dets [:, :5 ], tran_dets [:, - 1 ],
35- iou_thr )
36- merged_results .append (nms_dets .cpu ().numpy ())
37- return merged_results
95+ dets_per_cls = [results [i ][cls ] for i in range (num_patches )]
96+ masks_per_cls = None
97+
98+ dets_per_cls = [
99+ translate_bboxes (dets_per_cls [i ], offsets [i ])
100+ for i in range (num_patches )
101+ ]
102+ dets_per_cls = np .concatenate (dets_per_cls , axis = 0 )
103+ if with_mask :
104+ masks_placeholder = []
105+ for i , masks in enumerate (masks_per_cls ):
106+ translated = map_masks (masks , offsets [i ], img_shape )
107+ masks_placeholder .extend (translated )
108+ masks_per_cls = masks_placeholder
109+
110+ if dets_per_cls .size == 0 :
111+ merged_bboxes .append (dets_per_cls )
112+ if with_mask :
113+ merged_masks .append (masks_per_cls )
114+ else :
115+ dets_per_cls = torch .from_numpy (dets_per_cls ).to (device )
116+ nms_func = nms if dets_per_cls .size (1 ) == 5 else nms_rotated
117+ nms_dets , keeps = nms_func (dets_per_cls [:, :- 1 ],
118+ dets_per_cls [:, - 1 ], iou_thr )
119+ merged_bboxes .append (nms_dets .cpu ().numpy ())
120+ if with_mask :
121+ keeps = keeps .cpu ().numpy ()
122+ merged_masks .append ([masks_per_cls [i ] for i in keeps ])
123+
124+ if with_mask :
125+ return merged_bboxes , merged_masks
126+ else :
127+ return merged_bboxes
0 commit comments