Skip to content

Commit 8ddb18f

Browse files
authored
Do not skip full img tile classifier + Fix Sequencial Export Issue (#2174)
* add fixes * quality checks * fix mypy issue * fix test * fix bug * fix sequencial export issue * fix export bug * add quality check * fix nncf bug * add fixes * quality checks * fix mypy issue * fix test * fix bug * fix sequencial export issue * fix export bug * add quality check * fix nncf bug * fix nncf test
1 parent b23f5a7 commit 8ddb18f

File tree

7 files changed

+34
-9
lines changed

7 files changed

+34
-9
lines changed

otx/algorithms/common/adapters/mmcv/tasks/exporter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def run(self, cfg, **kwargs): # noqa: C901
3434
pipeline = dataset.get("pipeline", [])
3535
pipeline += cfg.data.test.get("pipeline", [])
3636
cfg.data.test.pipeline = pipeline
37+
for pipeline in cfg.data.test.pipeline:
38+
if pipeline.get("transforms", None):
39+
transforms = pipeline.transforms
40+
for transform in transforms:
41+
if transform.type == "Collect":
42+
for collect_key in transform["keys"]:
43+
if collect_key != "img":
44+
transform["keys"].remove(collect_key)
3745

3846
model_builder = kwargs.get("model_builder")
3947
try:

otx/algorithms/common/adapters/mmdeploy/apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def extract_partition(
254254
model_name: str = "model",
255255
):
256256
"""Function for extracting partition."""
257-
257+
reset_mark_function_count()
258258
model_onnx = MMdeployExporter.torch2onnx(
259259
output_dir,
260260
input_data,

otx/algorithms/detection/adapters/mmdet/datasets/tiling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def gen_single_img(self, result: Dict, dataset_idx: int) -> Dict:
141141
Returns:
142142
Dict: annotation with some other useful information for data pipeline.
143143
"""
144+
result["full_res_image"] = True
144145
result["tile_box"] = (0, 0, result["img_shape"][1], result["img_shape"][0])
145146
result["dataset_idx"] = dataset_idx
146147
result["original_shape_"] = result["img_shape"]
@@ -188,6 +189,7 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
188189
y_1 = loc_i
189190
y_2 = min(loc_i + self.tile_size, height)
190191
tile = copy.deepcopy(_tile)
192+
tile["full_res_image"] = False
191193
tile["original_shape_"] = img_shape
192194
tile["ori_shape"] = (y_2 - y_1, x_2 - x_1, 3)
193195
tile["img_shape"] = tile["ori_shape"]

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def make_fake_results(num_classes):
160160
mask_results.append([])
161161
return bbox_results, mask_results
162162

163-
def simple_test(self, img, img_metas, proposals=None, rescale=False):
163+
def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_image=False):
164164
"""Simple test.
165165
166166
Tile classifier is used to filter out images without any objects.
@@ -171,12 +171,15 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False):
171171
img_metas (list): image meta data
172172
proposals (list, optional): proposals. Defaults to None.
173173
rescale (bool, optional): rescale. Defaults to False.
174+
full_res_image (bool, optional): if the image is full resolution or not. Defaults to False.
174175
175176
Returns:
176177
tuple: MaskRCNN output
177178
"""
178-
179179
keep = self.tile_classifier.simple_test(img) > 0.45
180+
if isinstance(full_res_image, bool):
181+
full_res_image = [full_res_image]
182+
keep = full_res_image[0] | keep
180183

181184
results = []
182185
for _ in range(len(img)):

otx/algorithms/detection/adapters/mmdet/nncf/builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,13 @@ def build_nncf_detector( # pylint: disable=too-many-locals,too-many-statements
136136
for pipeline in config.data.test.pipeline:
137137
if not pipeline.type.startswith("LoadImage"):
138138
test_pipeline.append(pipeline)
139+
if pipeline.get("transforms", None):
140+
transforms = pipeline.transforms
141+
for transform in transforms:
142+
if transform.type == "Collect":
143+
for collect_key in transform["keys"]:
144+
if collect_key != "img":
145+
transform["keys"].remove(collect_key)
139146

140147
test_pipeline = Compose(test_pipeline)
141148
get_fake_input_fn = partial(

otx/algorithms/detection/adapters/mmdet/utils/config_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,12 @@ def patch_tiling(config, hparams, dataset=None):
342342
logger.info(f"Patch model from: {config.model.type} to CustomMaskRCNNTileOptimized")
343343
config.model.type = "CustomMaskRCNNTileOptimized"
344344

345+
for subset in ("val", "test"):
346+
if "transforms" in config.data[subset].pipeline[0]:
347+
transforms = config.data[subset].pipeline[0]["transforms"]
348+
if transforms[-1]["type"] == "Collect":
349+
transforms[-1]["keys"].append("full_res_image")
350+
345351
if config.model.backbone.type == "efficientnet_b2b":
346352
learning_rate = 0.002
347353
logger.info(

otx/api/configuration/helper/convert.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
from enum import Enum
11-
from typing import Type, TypeVar
11+
from typing import Type, TypeVar, Dict, Any, Union
1212

1313
import yaml
1414
from omegaconf import DictConfig, OmegaConf
@@ -98,7 +98,7 @@ def convert(
9898
enum_to_str: bool = False,
9999
id_to_str: bool = False,
100100
values_only: bool = False,
101-
) -> ConvertTypeVar:
101+
) -> Any:
102102
"""Convert a configuration object to either a yaml string, a dictionary or an OmegaConf DictConfig object.
103103
104104
Args:
@@ -129,11 +129,10 @@ def convert(
129129
config_dict["id"] = str(config_id) if config_id is not None else None
130130

131131
if target == str:
132-
result = yaml.dump(config_dict)
132+
return yaml.dump(config_dict)
133133
elif target == dict:
134-
result = config_dict # type: ignore
134+
return config_dict
135135
elif target == DictConfig:
136-
result = OmegaConf.create(config_dict)
136+
return OmegaConf.create(config_dict)
137137
else:
138138
raise ValueError("Unsupported conversion target! Supported target types are [str, dict, DictConfig]")
139-
return result # type: ignore

0 commit comments

Comments
 (0)