Skip to content

Commit 0edd257

Browse files
authored
[Enhance] Support mask in merge_results and huge_image_demo.py. (#280)
* Support masks merging * Update error report
1 parent d80310a commit 0edd257

File tree

2 files changed

+114
-20
lines changed

2 files changed

+114
-20
lines changed

mmrotate/apis/inference.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,9 @@ def inference_detector_by_patches(model,
8585
start += bs
8686

8787
results = merge_results(
88-
results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
88+
results,
89+
windows[:, :2],
90+
img_shape=(width, height),
91+
iou_thr=merge_iou_thr,
92+
device=device)
8993
return results
Lines changed: 109 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,78 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import numpy as np
33
import torch
4-
from mmcv.ops import nms_rotated
4+
from mmcv.ops import nms, nms_rotated
55

66

7-
def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
7+
def translate_bboxes(bboxes, offset):
    """Translate bboxes according to their shape.

    If the bbox shape is (n, 5), the bboxes are regarded as horizontal bboxes
    and in (x, y, x, y, score) format. If the bbox shape is (n, 6), the bboxes
    are regarded as rotated bboxes and in (x, y, w, h, theta, score) format.

    The input array is left untouched; a translated copy with the same dtype
    is returned, so translating the same patch results twice does not
    accumulate the offset.

    Args:
        bboxes (np.ndarray): The bboxes need to be translated. Its shape can
            only be (n, 5) and (n, 6).
        offset (np.ndarray): The offset to translate with shape being (2, ).

    Returns:
        np.ndarray: Translated bboxes.

    Raises:
        TypeError: If ``bboxes`` is not 2-D with 5 or 6 columns.
    """
    # Validate first so 1-D input gets the descriptive error, not IndexError.
    if bboxes.ndim != 2 or bboxes.shape[1] not in (5, 6):
        raise TypeError('Require the shape of `bboxes` to be (n, 5) or (n, 6),'
                        f' but get `bboxes` with shape being {bboxes.shape}.')

    translated = bboxes.copy()
    if translated.shape[1] == 5:
        # Horizontal boxes: shift both corner points (x1, y1, x2, y2).
        translated[:, :4] = translated[:, :4] + np.tile(offset, 2)
    else:
        # Rotated boxes: only the center (x, y) moves; w, h, theta are kept.
        translated[:, :2] = translated[:, :2] + offset
    return translated
30+
31+
32+
def map_masks(masks, offset, new_shape):
    """Map masks to the huge image.

    Each patch mask is pasted into an all-False canvas of the huge image's
    size at the position given by ``offset``. The part of a mask that would
    fall beyond the right or bottom border of the huge image is cropped.

    Args:
        masks (list[np.ndarray]): masks need to be mapped.
        offset (np.ndarray): The offset to translate with shape being (2, ),
            given as (x, y). Assumed to be non-negative integers.
        new_shape (tuple): A tuple of the huge image's width and height.

    Returns:
        list[np.ndarray]: Mapped masks, each a boolean array of shape
        (height, width) of the huge image. An empty ``masks`` is returned
        unchanged.
    """
    if not masks:
        return masks

    new_width, new_height = new_shape
    x_start, y_start = offset
    mapped = []
    for mask in masks:
        ori_height, ori_width = mask.shape[:2]

        # Clip the pasted region to the huge image's right/bottom borders.
        x_end = min(x_start + ori_width, new_width)
        y_end = min(y_start + ori_height, new_height)
        # A patch starting past the border has nothing visible; clamp to 0
        # so the slices below never receive a negative extent.
        visible_width = max(x_end - x_start, 0)
        visible_height = max(y_end - y_start, 0)

        # Builtin `bool` is used because the `np.bool` alias was removed
        # from NumPy.
        extended_mask = np.zeros((new_height, new_width), dtype=bool)
        if visible_width > 0 and visible_height > 0:
            extended_mask[y_start:y_end, x_start:x_end] = \
                mask[:visible_height, :visible_width]
        mapped.append(extended_mask)
    return mapped
67+
68+
69+
def merge_results(results, offsets, img_shape, iou_thr=0.1, device='cpu'):
    """Merge patch results via nms.

    Detections of every class are translated from patch coordinates to the
    huge image's coordinates, concatenated over all patches and filtered by
    NMS. Horizontal boxes (5 columns) go through ``nms``; rotated boxes
    (6 columns) go through ``nms_rotated``. When the patch results carry
    masks, the masks are mapped onto the huge image and only those of the
    boxes kept by NMS are returned.

    Args:
        results (list[np.ndarray] | list[tuple]): A list of patches results.
            An element that is a tuple is read as ``(bboxes, masks)``, each
            indexed by class.
        offsets (np.ndarray): Positions of the left top points of patches.
        img_shape (tuple): A tuple of the huge image's width and height.
        iou_thr (float): The IoU threshold of NMS.
        device (str): The device to call nms.

    Returns:
        list[np.ndarray] | tuple[list, list]: The merged bboxes of each
        class. If the patch results contain masks, a tuple of
        ``(merged_bboxes, merged_masks)`` is returned instead.
    """
    assert len(results) == offsets.shape[0], 'The `results` should has the ' \
        'same length with `offsets`.'
    with_mask = isinstance(results[0], tuple)
    num_classes = len(results[0][0]) if with_mask else len(results[0])

    merged_bboxes = []
    merged_masks = []
    for cls in range(num_classes):
        if with_mask:
            dets_per_patch = [patch[0][cls] for patch in results]
            masks_per_patch = [patch[1][cls] for patch in results]
        else:
            dets_per_patch = [patch[cls] for patch in results]
            masks_per_patch = []

        # Translate copies so the caller's per-patch arrays keep their
        # patch-local coordinates.
        dets_per_cls = np.concatenate([
            translate_bboxes(dets.copy(), offset)
            for dets, offset in zip(dets_per_patch, offsets)
        ], axis=0)

        # Masks are flattened in the same patch order as the detections so
        # that NMS keep-indices address both lists consistently.
        masks_per_cls = []
        for masks, offset in zip(masks_per_patch, offsets):
            masks_per_cls.extend(map_masks(masks, offset, img_shape))

        if dets_per_cls.size == 0:
            # Nothing detected for this class in any patch; skip NMS.
            merged_bboxes.append(dets_per_cls)
            if with_mask:
                merged_masks.append(masks_per_cls)
            continue

        dets_per_cls = torch.from_numpy(dets_per_cls).to(device)
        nms_func = nms if dets_per_cls.size(1) == 5 else nms_rotated
        # The last column is the score; everything before it is the box.
        nms_dets, keeps = nms_func(dets_per_cls[:, :-1], dets_per_cls[:, -1],
                                   iou_thr)
        merged_bboxes.append(nms_dets.cpu().numpy())
        if with_mask:
            keeps = keeps.cpu().numpy()
            merged_masks.append([masks_per_cls[i] for i in keeps])

    if with_mask:
        return merged_bboxes, merged_masks
    return merged_bboxes

0 commit comments

Comments
 (0)