Skip to content

Commit ac8a7dd

Browse files
authored
Update ModelAPI configuration (#2564)
* Update MAPI rt infor for detection * Upadte export info for cls, det and seg * Update unit tests
1 parent a576962 commit ac8a7dd

File tree

5 files changed

+29
-7
lines changed

5 files changed

+29
-7
lines changed

src/otx/algorithms/classification/utils/cls_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,20 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
9898
"""Get ModelAPI config."""
9999
mapi_config = {}
100100
mapi_config[("model_info", "model_type")] = "Classification"
101+
mapi_config[("model_info", "task_type")] = "classification"
101102
mapi_config[("model_info", "confidence_threshold")] = str(inference_config["confidence_threshold"])
102103
mapi_config[("model_info", "multilabel")] = str(inference_config["multilabel"])
103104
mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"])
104105
mapi_config[("model_info", "output_raw_scores")] = str(True)
105106

106107
all_labels = ""
108+
all_label_ids = ""
107109
for lbl in label_schema.get_labels(include_empty=False):
108110
all_labels += lbl.name.replace(" ", "_") + " "
109-
all_labels = all_labels.strip()
110-
mapi_config[("model_info", "labels")] = all_labels
111+
all_label_ids += f"{lbl.id_} "
112+
113+
mapi_config[("model_info", "labels")] = all_labels.strip()
114+
mapi_config[("model_info", "label_ids")] = all_label_ids.strip()
111115

112116
hierarchical_config = {}
113117
hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema)

src/otx/algorithms/detection/utils/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,22 @@ def get_det_model_api_configuration(
110110
"""Get ModelAPI config."""
111111
omz_config = {}
112112
all_labels = ""
113+
all_label_ids = ""
113114
if task_type == TaskType.DETECTION:
114115
omz_config[("model_info", "model_type")] = "ssd"
116+
omz_config[("model_info", "task_type")] = "detection"
115117
if task_type == TaskType.INSTANCE_SEGMENTATION:
116118
omz_config[("model_info", "model_type")] = "MaskRCNN"
119+
omz_config[("model_info", "task_type")] = "instance_segmentation"
117120
all_labels = "otx_empty_lbl "
121+
all_label_ids = "None "
118122
if tiling_parameters.enable_tiling:
119123
omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox"
120124
if task_type == TaskType.ROTATED_DETECTION:
121-
omz_config[("model_info", "model_type")] = "rotated_detection"
125+
omz_config[("model_info", "model_type")] = "MaskRCNN"
126+
omz_config[("model_info", "task_type")] = "rotated_detection"
122127
all_labels = "otx_empty_lbl "
128+
all_label_ids = "None "
123129
if tiling_parameters.enable_tiling:
124130
omz_config[("model_info", "resize_type")] = "fit_to_window_letterbox"
125131

@@ -137,9 +143,10 @@ def get_det_model_api_configuration(
137143

138144
for lbl in label_schema.get_labels(include_empty=False):
139145
all_labels += lbl.name.replace(" ", "_") + " "
140-
all_labels = all_labels.strip()
146+
all_label_ids += f"{lbl.id_} "
141147

142-
omz_config[("model_info", "labels")] = all_labels
148+
omz_config[("model_info", "labels")] = all_labels.strip()
149+
omz_config[("model_info", "label_ids")] = all_label_ids.strip()
143150

144151
return omz_config
145152

src/otx/algorithms/segmentation/utils/metadata.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams: ConfigDict):
1313
"""Get ModelAPI config."""
1414
all_labels = ""
15+
all_label_ids = ""
1516
for lbl in label_schema.get_labels(include_empty=False):
1617
all_labels += lbl.name.replace(" ", "_") + " "
17-
all_labels = all_labels.strip()
18+
all_label_ids += f"{lbl.id_} "
1819

1920
return {
2021
("model_info", "model_type"): "Segmentation",
2122
("model_info", "soft_threshold"): str(hyperparams.postprocessing.soft_threshold),
2223
("model_info", "blur_strength"): str(hyperparams.postprocessing.blur_strength),
23-
("model_info", "labels"): all_labels,
24+
("model_info", "labels"): all_labels.strip(),
25+
("model_info", "label_ids"): all_label_ids.strip(),
26+
("model_info", "task_type"): "segmentation",
2427
}

tests/unit/algorithms/classification/utils/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,7 @@ def test_get_cls_model_api_configuration(default_hierarchical_data):
9393
assert len(model_api_cfg) > 0
9494
assert model_api_cfg[("model_info", "confidence_threshold")] == str(config["confidence_threshold"])
9595
assert ("model_info", "hierarchical_config") in model_api_cfg
96+
assert ("model_info", "labels") in model_api_cfg
97+
assert ("model_info", "label_ids") in model_api_cfg
98+
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
99+
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())

tests/unit/algorithms/detection/utils/test_detection_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,7 @@ def test_get_det_model_api_configuration():
3434
tiling_parameters.tile_overlap / tiling_parameters.tile_ir_scale_factor
3535
)
3636
assert model_api_cfg[("model_info", "max_pred_number")] == str(tiling_parameters.tile_max_number)
37+
assert ("model_info", "labels") in model_api_cfg
38+
assert ("model_info", "label_ids") in model_api_cfg
39+
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
40+
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())

0 commit comments

Comments
 (0)