Skip to content

Commit 20b2aff

Browse files
[Fix] Fix batch inference error for Mask R-CNN (#1575)
* fix Mask R-CNN for multi-batch * fix flake8 * fix bug * use Sequence instead list * fix docstring * only test_img accept list input
1 parent 20e0563 commit 20b2aff

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

mmdeploy/apis/visualize.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def visualize_model(model_cfg: Union[str, mmcv.Config],
1212
deploy_cfg: Union[str, mmcv.Config],
1313
model: Union[str, Sequence[str]],
14-
img: Union[str, np.ndarray],
14+
img: Union[str, np.ndarray, Sequence[str]],
1515
device: str,
1616
backend: Optional[Backend] = None,
1717
output_file: Optional[str] = None,
@@ -36,7 +36,8 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
3636
deploy_cfg (str | mmcv.Config): Deployment config file or Config
3737
object.
3838
model (str | list[str], BaseSubtask): Input model or file(s).
39-
img (str | np.ndarray): Input image file or numpy array for inference.
39+
img (str | np.ndarray | Sequence[str]): Input image file(s) or numpy
40+
array for inference.
4041
device (str): A string specifying device type.
4142
backend (Backend): Specifying backend type, defaults to `None`.
4243
output_file (str): Output file to save visualized image, defaults to
@@ -74,14 +75,16 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
7475
# check headless
7576
import tkinter
7677
tkinter.Tk()
77-
78-
task_processor.visualize(
79-
image=img,
80-
model=model,
81-
result=result,
82-
output_file=output_file,
83-
window_name=backend.value,
84-
show_result=show_result)
78+
if not isinstance(img, Sequence):
79+
img = [img]
80+
for single_img in img:
81+
task_processor.visualize(
82+
image=single_img,
83+
model=model,
84+
result=result,
85+
output_file=output_file,
86+
window_name=backend.value,
87+
show_result=show_result)
8588
except Exception as e:
8689
from mmdeploy.utils import get_root_logger
8790
logger = get_root_logger()

mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def mask_test_mixin__simple_test_mask(ctx, self, x, img_metas, det_bboxes,
9696
# expand might lead to static shape, use broadcast instead
9797
batch_index = torch.arange(
9898
det_bboxes.size(0), device=det_bboxes.device).float().view(
99-
-1, 1) + det_bboxes.new_zeros(
99+
-1, 1, 1) + det_bboxes.new_zeros(
100100
(det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1)
101101
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
102102
mask_rois = mask_rois.view(-1, 5)

tools/deploy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def parse_args():
2727
parser.add_argument('checkpoint', help='model checkpoint path')
2828
parser.add_argument('img', help='image used to convert model model')
2929
parser.add_argument(
30-
'--test-img', default=None, help='image used to test model')
30+
'--test-img',
31+
default=None,
32+
type=str,
33+
nargs='+',
34+
help='image used to test model')
3135
parser.add_argument(
3236
'--work-dir',
3337
default=os.getcwd(),

0 commit comments

Comments
 (0)