Skip to content

Commit c4dc10d

Browse files
authored
add condinst ut & update docs (open-mmlab#2481)
1 parent 4c376d9 commit c4dc10d

File tree

3 files changed

+221
-0
lines changed

3 files changed

+221
-0
lines changed

docs/en/04-supported-codebases/mmdet.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
218218
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
219219
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
220220
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
221+
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
221222
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
222223
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
223224
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |

docs/zh_cn/04-supported-codebases/mmdet.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- [后端模型推理](#后端模型推理)
1111
- [SDK 模型推理](#sdk-模型推理)
1212
- [模型支持列表](#模型支持列表)
13+
- [注意事项](#注意事项)
1314

1415
______________________________________________________________________
1516

@@ -220,6 +221,7 @@ cv2.imwrite('output_detection.png', img)
220221
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
221222
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
222223
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
224+
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
223225
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
224226
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
225227
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |

tests/test_codebase/test_mmdet/test_mmdet_models.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2364,3 +2364,221 @@ def test_solov2_head_predict_by_feat(backend_type):
23642364
atol=1e-05)
23652365
else:
23662366
assert rewrite_outputs is not None
2367+
2368+
2369+
def get_condinst_bbox_head():
2370+
"""condinst Bbox Head Config."""
2371+
test_cfg = Config(
2372+
dict(
2373+
mask_thr=0.5,
2374+
max_per_img=100,
2375+
min_bbox_size=0,
2376+
nms=dict(iou_threshold=0.6, type='nms'),
2377+
nms_pre=1000,
2378+
score_thr=0.05))
2379+
from mmdet.models.dense_heads import CondInstBboxHead
2380+
model = CondInstBboxHead(
2381+
center_sampling=True,
2382+
centerness_on_reg=True,
2383+
conv_bias=True,
2384+
dcn_on_last_conv=False,
2385+
feat_channels=256,
2386+
in_channels=256,
2387+
loss_bbox=dict(loss_weight=1.0, type='GIoULoss'),
2388+
loss_centerness=dict(
2389+
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
2390+
loss_cls=dict(
2391+
alpha=0.25,
2392+
gamma=2.0,
2393+
loss_weight=1.0,
2394+
type='FocalLoss',
2395+
use_sigmoid=True),
2396+
norm_on_bbox=True,
2397+
num_classes=80,
2398+
num_params=169,
2399+
stacked_convs=4,
2400+
strides=[
2401+
8,
2402+
16,
2403+
32,
2404+
64,
2405+
128,
2406+
],
2407+
test_cfg=test_cfg,
2408+
)
2409+
2410+
model.requires_grad_(False)
2411+
return model
2412+
2413+
2414+
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
2415+
def test_condinst_bbox_head_predict_by_feat(backend_type):
2416+
"""Test predict_by_feat rewrite of condinst bbox head."""
2417+
check_backend(backend_type)
2418+
condinst_bbox_head = get_condinst_bbox_head()
2419+
condinst_bbox_head.cpu().eval()
2420+
s = 128
2421+
batch_img_metas = [{
2422+
'scale_factor': np.ones(4),
2423+
'pad_shape': (s, s, 3),
2424+
'img_shape': (s, s, 3)
2425+
}]
2426+
2427+
output_names = ['dets', 'labels', 'param_preds', 'points', 'strides']
2428+
deploy_cfg = Config(
2429+
dict(
2430+
backend_config=dict(type=backend_type.value),
2431+
onnx_config=dict(output_names=output_names, input_shape=None),
2432+
codebase_config=dict(
2433+
type='mmdet',
2434+
task='ObjectDetection',
2435+
post_processing=dict(
2436+
score_threshold=0.05,
2437+
confidence_threshold=0.005,
2438+
iou_threshold=0.5,
2439+
max_output_boxes_per_class=200,
2440+
pre_top_k=5000,
2441+
keep_top_k=100,
2442+
background_label_id=-1,
2443+
export_postprocess_mask=False))))
2444+
2445+
seed_everything(1234)
2446+
cls_scores = [
2447+
torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i))
2448+
for i in range(5, 0, -1)
2449+
]
2450+
seed_everything(5678)
2451+
bbox_preds = [
2452+
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
2453+
]
2454+
seed_everything(9101)
2455+
score_factors = [
2456+
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
2457+
]
2458+
seed_everything(1121)
2459+
param_preds = [
2460+
torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i))
2461+
for i in range(5, 0, -1)
2462+
]
2463+
2464+
# to get outputs of onnx model after rewrite
2465+
wrapped_model = WrapModel(
2466+
condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
2467+
rewrite_inputs = {
2468+
'cls_scores': cls_scores,
2469+
'bbox_preds': bbox_preds,
2470+
'score_factors': score_factors,
2471+
'param_preds': param_preds,
2472+
}
2473+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
2474+
wrapped_model=wrapped_model,
2475+
model_inputs=rewrite_inputs,
2476+
deploy_cfg=deploy_cfg)
2477+
2478+
if is_backend_output:
2479+
dets = rewrite_outputs[0]
2480+
labels = rewrite_outputs[1]
2481+
param_preds = rewrite_outputs[2]
2482+
points = rewrite_outputs[3]
2483+
strides = rewrite_outputs[4]
2484+
assert dets.shape[-1] == 5
2485+
assert labels is not None
2486+
assert param_preds.shape[-1] == condinst_bbox_head.num_params
2487+
assert points.shape[-1] == 2
2488+
assert strides is not None
2489+
else:
2490+
assert rewrite_outputs is not None
2491+
2492+
2493+
def get_condinst_mask_head():
2494+
"""condinst Mask Head Config."""
2495+
test_cfg = Config(
2496+
dict(
2497+
mask_thr=0.5,
2498+
max_per_img=100,
2499+
min_bbox_size=0,
2500+
nms=dict(iou_threshold=0.6, type='nms'),
2501+
nms_pre=1000,
2502+
score_thr=0.05))
2503+
from mmdet.models.dense_heads import CondInstMaskHead
2504+
model = CondInstMaskHead(
2505+
mask_feature_head=dict(
2506+
end_level=2,
2507+
feat_channels=128,
2508+
in_channels=256,
2509+
mask_stride=8,
2510+
norm_cfg=dict(requires_grad=True, type='BN'),
2511+
num_stacked_convs=4,
2512+
out_channels=8,
2513+
start_level=0),
2514+
num_layers=3,
2515+
feat_channels=8,
2516+
mask_out_stride=4,
2517+
size_of_interest=8,
2518+
max_masks_to_train=300,
2519+
loss_mask=dict(
2520+
activate=True,
2521+
eps=5e-06,
2522+
loss_weight=1.0,
2523+
type='DiceLoss',
2524+
use_sigmoid=True),
2525+
test_cfg=test_cfg,
2526+
)
2527+
2528+
model.requires_grad_(False)
2529+
return model
2530+
2531+
2532+
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
2533+
def test_condinst_mask_head_forward(backend_type):
2534+
"""Test predict_by_feat rewrite of condinst mask head."""
2535+
check_backend(backend_type)
2536+
2537+
output_names = ['mask_preds']
2538+
deploy_cfg = Config(
2539+
dict(
2540+
backend_config=dict(type=backend_type.value),
2541+
onnx_config=dict(output_names=output_names, input_shape=None),
2542+
codebase_config=dict(type='mmdet', task='ObjectDetection')))
2543+
2544+
class TestCondInstMaskHeadModel(torch.nn.Module):
2545+
2546+
def __init__(self, condinst_mask_head):
2547+
super(TestCondInstMaskHeadModel, self).__init__()
2548+
self.mask_head = condinst_mask_head
2549+
2550+
def forward(self, x, param_preds, points, strides):
2551+
positive_infos = dict(
2552+
param_preds=param_preds, points=points, strides=strides)
2553+
return self.mask_head(x, positive_infos)
2554+
2555+
mask_head = get_condinst_mask_head()
2556+
level = mask_head.mask_feature_head.end_level - \
2557+
mask_head.mask_feature_head.start_level + 1
2558+
2559+
condinst_mask_head = TestCondInstMaskHeadModel(mask_head)
2560+
condinst_mask_head.cpu().eval()
2561+
2562+
seed_everything(1234)
2563+
x = [torch.rand(1, 256, pow(2, i), pow(2, i)) for i in range(level, 0, -1)]
2564+
seed_everything(5678)
2565+
param_preds = torch.rand(1, 100, 169)
2566+
seed_everything(9101)
2567+
points = torch.rand(1, 100, 2)
2568+
seed_everything(1121)
2569+
strides = torch.rand(1, 100)
2570+
2571+
# to get outputs of onnx model after rewrite
2572+
wrapped_model = WrapModel(condinst_mask_head, 'forward')
2573+
rewrite_inputs = {
2574+
'x': x,
2575+
'param_preds': param_preds,
2576+
'points': points,
2577+
'strides': strides
2578+
}
2579+
rewrite_outputs, _ = get_rewrite_outputs(
2580+
wrapped_model=wrapped_model,
2581+
model_inputs=rewrite_inputs,
2582+
deploy_cfg=deploy_cfg)
2583+
2584+
assert rewrite_outputs is not None

0 commit comments

Comments
 (0)