
Commit 5ebd10b

Add with argmax in config for mmseg (#2038)
* add with_argmax for model conversion in mmseg
* resolve lint
1 parent e9c0092 commit 5ebd10b

8 files changed, +20 -63 lines changed

configs/mmseg/segmentation_rknn-fp16_static-320x320.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 onnx_config = dict(input_shape=[320, 320])

-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)

 backend_config = dict(
     input_size_list=[[3, 320, 320]],

configs/mmseg/segmentation_rknn-int8_static-320x320.py

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@

 onnx_config = dict(input_shape=[320, 320])

-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)

 backend_config = dict(input_size_list=[[3, 320, 320]])
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 _base_ = ['../_base_/onnx_config.py']
-codebase_config = dict(type='mmseg', task='Segmentation')
+codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)

docs/en/04-supported-codebases/mmseg.md

Lines changed: 2 additions & 0 deletions
@@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.

 - For models that only support static shape, you should use a static-shape deployment config file such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+
+- For users who prefer the deployed model to output a probability feature map, set `codebase_config = dict(with_argmax=False)` in the deploy config.
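
To make that note concrete, a user-side deploy config can inherit an existing mmseg deployment config and only flip the flag. A minimal sketch (the inherited file path and the config filename are illustrative; only the `with_argmax` key comes from this commit):

# Hypothetical user config, e.g. my_segmentation_no_argmax.py.
# Inherit a shipped mmseg deployment config and keep the raw class-score map
# instead of the argmax label map.
_base_ = ['./segmentation_tensorrt_static-1024x2048.py']

# Merged into the inherited codebase_config (type='mmseg', task='Segmentation');
# with_argmax=False makes the converted model return the probability feature map.
codebase_config = dict(with_argmax=False)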

docs/zh_cn/04-supported-codebases/mmseg.md

Lines changed: 2 additions & 0 deletions
@@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static input, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) does not support dynamic input in most inference frameworks.

 - For models that only support static shape, use a static-shape deployment config file such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+
+- For users who prefer the deployed model to output a probability feature map, setting `codebase_config = dict(with_argmax=False)` in the deploy config is sufficient.

mmdeploy/codebase/mmseg/deploy/segmentation.py

Lines changed: 5 additions & 1 deletion
@@ -13,7 +13,8 @@
 from mmengine.registry import Registry

 from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
-from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
+from mmdeploy.utils import (Codebase, Task, get_codebase_config,
+                            get_input_shape, get_root_logger)


 def process_model_config(model_cfg: mmengine.Config,
@@ -303,6 +304,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
         if isinstance(params, list):
             params = params[-1]
         postprocess = dict(params=params, type='ResizeMask')
+        with_argmax = get_codebase_config(self.deploy_cfg).get(
+            'with_argmax', True)
+        postprocess['with_argmax'] = with_argmax
         return postprocess

     def get_model_name(self, *args, **kwargs) -> str:
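
For reference, the flag lookup added to `get_postprocess` above can be exercised on its own. A minimal sketch (the in-memory `deploy_cfg` and the trimmed `postprocess` dict are illustrative; the real method also carries resize `params`):

# Sketch: how the SDK postprocess config picks up with_argmax (default True).
from mmengine import Config

from mmdeploy.utils import get_codebase_config

deploy_cfg = Config(
    dict(codebase_config=dict(
        type='mmseg', task='Segmentation', with_argmax=False)))

# Same default as in get_postprocess(): a missing key means the SDK applies argmax.
with_argmax = get_codebase_config(deploy_cfg).get('with_argmax', True)
postprocess = dict(type='ResizeMask', with_argmax=with_argmax)
print(postprocess)  # {'type': 'ResizeMask', 'with_argmax': False}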

mmdeploy/codebase/mmseg/deploy/segmentation_model.py

Lines changed: 3 additions & 33 deletions
@@ -105,6 +105,9 @@ def pack_result(self, batch_outputs: torch.Tensor,
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape
             metainfo = data_sample.metainfo
+            if get_codebase_config(self.deploy_cfg).get('with_argmax',
+                                                        True) is False:
+                seg_pred = seg_pred.argmax(dim=0, keepdim=True)
             if metainfo['ori_shape'] != metainfo['img_shape']:
                 from mmseg.models.utils import resize
                 ori_type = seg_pred.dtype
@@ -119,39 +122,6 @@ def pack_result(self, batch_outputs: torch.Tensor,
         return predictions


-@__BACKEND_MODEL.register_module('rknn')
-class RKNNModel(End2EndModel):
-    """SDK inference class, converts RKNN output to mmseg format."""
-
-    def forward(self,
-                inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
-        """Run forward inference.
-
-        Args:
-            inputs (Tensor): Inputs with shape (N, C, H, W).
-            data_samples (list[:obj:`SegDataSample`]): The seg data
-                samples. It usually includes information such as
-                `metainfo` and `gt_sem_seg`. Default to None.
-
-        Returns:
-            list: A list contains predictions.
-        """
-        assert mode == 'predict', \
-            'Backend model only support mode==predict,' f' but get {mode}'
-        if inputs.device != torch.device(self.device):
-            get_root_logger().warning(f'expect input device {self.device}'
-                                      f' but get {inputs.device}.')
-        inputs = inputs.to(self.device)
-        batch_outputs = self.wrapper({self.input_name: inputs})
-        batch_outputs = [
-            output.argmax(dim=1, keepdim=True)
-            for output in batch_outputs.values()
-        ]
-        return self.pack_result(batch_outputs[0], data_samples)
-
-
 @__BACKEND_MODEL.register_module('vacc_seg')
 class VACCModel(End2EndModel):
     """SDK inference class, converts VACC output to mmseg format."""
Lines changed: 5 additions & 26 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.utils.constants import Backend
+from mmdeploy.utils import get_codebase_config


 @FUNCTION_REWRITER.register_rewriter(
@@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
     x = self.extract_feat(inputs)
     seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)

+    ctx = FUNCTION_REWRITER.get_context()
+    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
+        return seg_logit
+
     # mark seg_head
     @mark('decode_head', outputs=['output'])
     def __mark_seg_logit(seg_logit):
@@ -35,28 +39,3 @@ def __mark_seg_logit(seg_logit):

     seg_pred = seg_logit.argmax(dim=1, keepdim=True)
     return seg_pred
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.segmentors.EncoderDecoder.predict',
-    backend=Backend.RKNN.value)
-def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
-    """Rewrite `predict` for RKNN backend.
-
-    Early return to avoid argmax operator.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        self: The instance of the original class.
-        inputs (Tensor): Inputs with shape (N, C, H, W).
-        data_samples (SampleList): The seg data samples.
-
-    Returns:
-        torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
-    """
-    batch_img_metas = []
-    for data_sample in data_samples:
-        batch_img_metas.append(data_sample.metainfo)
-    x = self.extract_feat(inputs)
-    seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
-    return seg_logit
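
The rewriter change above folds the old RKNN-only early return into the generic `EncoderDecoder.predict` rewrite: with `with_argmax=False` the exported graph ends at the logits, otherwise it appends the argmax. A shape-only illustration of the two outputs (random tensor, not a real model):

import torch

n, num_classes, h, w = 1, 19, 64, 64
seg_logit = torch.randn(n, num_classes, h, w)      # with_argmax=False: returned as-is

seg_pred = seg_logit.argmax(dim=1, keepdim=True)   # default path (with_argmax=True)
print(seg_logit.shape, seg_logit.dtype)  # torch.Size([1, 19, 64, 64]) torch.float32
print(seg_pred.shape, seg_pred.dtype)    # torch.Size([1, 1, 64, 64]) torch.int64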
