Skip to content

Commit a0780a8

Browse files
author
Songki Choi
authored
Make max_num_detections configurable (#2647)
* Make max_num_detections configurable * Fix RCNN case with integration test * Apply max_num_detections to train_cfg, too --------- Signed-off-by: Songki Choi <[email protected]>
1 parent a2545f9 commit a0780a8

File tree

18 files changed

+145
-91
lines changed

18 files changed

+145
-91
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
99
- Update ModelAPI configuration(<https://github.com/openvinotoolkit/training_extensions/pull/2564>)
1010
- Add Anomaly modelAPI changes (<https://github.com/openvinotoolkit/training_extensions/pull/2563>)
1111
- Update Image numpy access (<https://github.com/openvinotoolkit/training_extensions/pull/2586>)
12+
- Make max_num_detections configurable (<https://github.com/openvinotoolkit/training_extensions/pull/2647>)
1213

1314
### Bug fixes
1415

src/otx/algorithms/common/configs/training_base.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Base Configuration of OTX Common Algorithms."""
22

3-
# Copyright (C) 2022 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2022-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
from sys import maxsize
187

@@ -227,6 +216,16 @@ class BasePostprocessing(ParameterGroup):
227216
affects_outcome_of=ModelLifecycle.INFERENCE,
228217
)
229218

219+
max_num_detections = configurable_integer(
220+
header="Maximum number of detections per image",
221+
description="Extra detection outputs will be discarded in non-maximum suppression process. "
222+
"Defaults to 0, which means per-model default value.",
223+
default_value=0,
224+
min_value=0,
225+
max_value=10000,
226+
affects_outcome_of=ModelLifecycle.INFERENCE,
227+
)
228+
230229
use_ellipse_shapes = configurable_boolean(
231230
default_value=False,
232231
header="Use ellipse shapes",

src/otx/algorithms/detection/adapters/mmdet/configurer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,14 @@ def configure(
6464
ir_options=None,
6565
data_classes=None,
6666
model_classes=None,
67+
max_num_detections=0,
6768
):
6869
"""Create MMCV-consumable config from given inputs."""
6970
logger.info(f"configure!: training={training}")
7071

7172
self.configure_base(cfg, data_cfg, data_classes, model_classes)
7273
self.configure_device(cfg, training)
73-
self.configure_model(cfg, ir_options)
74+
self.configure_model(cfg, ir_options, max_num_detections)
7475
self.configure_ckpt(cfg, model_ckpt)
7576
self.configure_data(cfg, training, data_cfg)
7677
self.configure_regularization(cfg, training)
@@ -113,7 +114,7 @@ def configure_base(self, cfg, data_cfg, data_classes, model_classes):
113114
new_classes = np.setdiff1d(data_classes, model_classes).tolist()
114115
train_data_cfg["new_classes"] = new_classes
115116

116-
def configure_model(self, cfg, ir_options): # noqa: C901
117+
def configure_model(self, cfg, ir_options, max_num_detections=0): # noqa: C901
117118
"""Patch config's model.
118119
119120
Change model type to super type
@@ -149,6 +150,23 @@ def is_mmov_model(key, value):
149150
{"model_path": ir_model_path, "weight_path": ir_weight_path, "init_weight": ir_weight_init},
150151
)
151152

153+
# Test config
154+
if max_num_detections > 0:
155+
logger.info(f"Model max_num_detections: {max_num_detections}")
156+
test_cfg = cfg.model.test_cfg
157+
test_cfg.max_per_img = max_num_detections
158+
test_cfg.nms_pre = max_num_detections * 10
159+
# Special cases for 2-stage detectors (e.g. MaskRCNN)
160+
if hasattr(test_cfg, "rpn"):
161+
test_cfg.rpn.nms_pre = max_num_detections * 20
162+
test_cfg.rpn.max_per_img = max_num_detections * 10
163+
if hasattr(test_cfg, "rcnn"):
164+
test_cfg.rcnn.max_per_img = max_num_detections
165+
train_cfg = cfg.model.train_cfg
166+
if hasattr(train_cfg, "rpn_proposal"):
167+
train_cfg.rpn_proposal.nms_pre = max_num_detections * 20
168+
train_cfg.rpn_proposal.max_per_img = max_num_detections * 10
169+
152170
def configure_data(self, cfg, training, data_cfg): # noqa: C901
153171
"""Patch cfg.data.
154172

src/otx/algorithms/detection/adapters/mmdet/task.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Task of OTX Detection using mmdetection training backend."""
22

33
# Copyright (C) 2023 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
4+
# SPDX-License-Identifier: Apache-2.0
165

176
import glob
187
import io
@@ -206,6 +195,7 @@ def configure(self, training=True, subset="train", ir_options=None, train_datase
206195
ir_options,
207196
data_classes,
208197
model_classes,
198+
self.max_num_detections,
209199
)
210200
if should_cluster_anchors(self._recipe_cfg):
211201
if train_dataset is not None:
@@ -513,6 +503,12 @@ def _export_model(
513503
assert len(self._precision) == 1
514504
export_options["precision"] = str(self._precision[0])
515505
export_options["type"] = str(export_format)
506+
if self.max_num_detections > 0:
507+
logger.info(f"Export max_num_detections: {self.max_num_detections}")
508+
post_proc_cfg = export_options["deploy_cfg"]["codebase_config"]["post_processing"]
509+
post_proc_cfg["max_output_boxes_per_class"] = self.max_num_detections
510+
post_proc_cfg["keep_top_k"] = self.max_num_detections
511+
post_proc_cfg["pre_top_k"] = self.max_num_detections * 10
516512

517513
export_options["deploy_cfg"]["dump_features"] = dump_features
518514
if dump_features:

src/otx/algorithms/detection/adapters/openvino/task.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Openvino Task of Detection."""
22

3-
# Copyright (C) 2021 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2021-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
import copy
187
import io

src/otx/algorithms/detection/configs/base/configuration.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Configuration file of OTX Detection."""
22

3-
# Copyright (C) 2022 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2022-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
from attr import attrs
187

src/otx/algorithms/detection/configs/detection/configuration.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,25 @@ postprocessing:
258258
value: 0.01
259259
visible_in_ui: true
260260
warning: null
261+
max_num_detections:
262+
affects_outcome_of: INFERENCE
263+
default_value: 0
264+
description:
265+
Extra detection outputs will be discarded in non-maximum suppression process.
266+
Defaults to 0, which means per-model default values.
267+
editable: true
268+
header: Maximum number of detections per image
269+
max_value: 10000
270+
min_value: 0
271+
type: INTEGER
272+
ui_rules:
273+
action: DISABLE_EDITING
274+
operator: AND
275+
rules: []
276+
type: UI_RULES
277+
value: 0
278+
visible_in_ui: true
279+
warning: null
261280
use_ellipse_shapes:
262281
affects_outcome_of: INFERENCE
263282
default_value: false

src/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,25 @@ postprocessing:
258258
value: 0.01
259259
visible_in_ui: true
260260
warning: null
261+
max_num_detections:
262+
affects_outcome_of: INFERENCE
263+
default_value: 0
264+
description:
265+
Extra detection outputs will be discarded in non-maximum suppression process.
266+
Defaults to 0, which means per-model default values.
267+
editable: true
268+
header: Maximum number of detections per image
269+
max_value: 10000
270+
min_value: 0
271+
type: INTEGER
272+
ui_rules:
273+
action: DISABLE_EDITING
274+
operator: AND
275+
rules: []
276+
type: UI_RULES
277+
value: 0
278+
visible_in_ui: true
279+
warning: null
261280
use_ellipse_shapes:
262281
affects_outcome_of: INFERENCE
263282
default_value: false

src/otx/algorithms/detection/configs/instance_segmentation/convnext_maskrcnn/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@
115115
nms=dict(type="nms", iou_threshold=0.7),
116116
min_bbox_size=0,
117117
),
118-
rcnn=dict(
119-
score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5, max_num=100), max_per_img=100, mask_thr_binary=0.5
120-
),
118+
rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5),
121119
),
122120
)
123121

src/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Model configuration of Resnet50-MaskRCNN model for Instance-Seg Task."""
22

3-
# Copyright (C) 2022 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2022-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
# pylint: disable=invalid-name
187

@@ -149,7 +138,7 @@
149138
),
150139
rcnn=dict(
151140
score_thr=0.05,
152-
nms=dict(type="nms", iou_threshold=0.5, max_num=100),
141+
nms=dict(type="nms", iou_threshold=0.5),
153142
max_per_img=100,
154143
mask_thr_binary=0.5,
155144
),

0 commit comments

Comments
 (0)