Skip to content

Commit a2d2c81

Browse files
authored
Fix incorrect all_groups order configuration in HLabelInfo (#4067)
* Fix all_labels * Update CHAGELOG * label_groups change
1 parent 8bba44c commit a2d2c81

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ All notable changes to this project will be documented in this file.
6565
(<https://github.com/openvinotoolkit/training_extensions/pull/4018>)
6666
- Update HPO interface
6767
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)
68+
- Bump onnx to 1.17.0 to omit CVE-2024-5187
69+
(<https://github.com/openvinotoolkit/training_extensions/pull/4063>)
6870

6971
### Bug fixes
7072

@@ -104,6 +106,8 @@ All notable changes to this project will be documented in this file.
104106
(<https://github.com/openvinotoolkit/training_extensions/pull/4052>)
105107
- Fix applying model's hparams when loading model from checkpoint
106108
(<https://github.com/openvinotoolkit/training_extensions/pull/4057>)
109+
- Fix incorrect all_groups order configuration in HLabelInfo
110+
(<https://github.com/openvinotoolkit/training_extensions/pull/4067>)
107111

108112
## \[v2.1.0\]
109113

src/otx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
__version__ = "2.2.0rc11"
6+
__version__ = "2.2.0rc12"
77

88
import os
99
from pathlib import Path

src/otx/core/types/label.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,8 @@ def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInf
169169
dm_label_categories (LabelCategories): the label categories of datumaro.
170170
"""
171171

172-
def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str, Any]:
172+
def get_exclusive_group_info(exclusive_groups: list[Label | list[Label]]) -> dict[str, Any]:
173173
"""Get exclusive group information."""
174-
exclusive_groups = [g for g in all_groups if len(g) > 1]
175-
176174
last_logits_pos = 0
177175
num_single_label_classes = 0
178176
head_idx_to_logits_range = {}
@@ -193,12 +191,10 @@ def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str,
193191
}
194192

195193
def get_single_label_group_info(
196-
all_groups: list[Label | list[Label]],
194+
single_label_groups: list[Label | list[Label]],
197195
num_exclusive_groups: int,
198196
) -> dict[str, Any]:
199197
"""Get single label group information."""
200-
single_label_groups = [g for g in all_groups if len(g) == 1]
201-
202198
class_to_idx = {}
203199

204200
for i, group in enumerate(single_label_groups):
@@ -256,8 +252,13 @@ def convert_labels_if_needed(
256252
label_names = [item.name for item in dm_label_categories.items]
257253
all_groups = convert_labels_if_needed(dm_label_categories, label_names)
258254

259-
exclusive_group_info = get_exclusive_group_info(all_groups)
260-
single_label_group_info = get_single_label_group_info(all_groups, exclusive_group_info["num_multiclass_heads"])
255+
exclusive_groups = [g for g in all_groups if len(g) > 1]
256+
exclusive_group_info = get_exclusive_group_info(exclusive_groups)
257+
single_label_groups = [g for g in all_groups if len(g) == 1]
258+
single_label_group_info = get_single_label_group_info(
259+
single_label_groups,
260+
exclusive_group_info["num_multiclass_heads"],
261+
)
261262

262263
merged_class_to_idx = merge_class_to_idx(
263264
exclusive_group_info["class_to_idx"],
@@ -268,13 +269,13 @@ def convert_labels_if_needed(
268269

269270
return HLabelInfo(
270271
label_names=label_names,
271-
label_groups=all_groups,
272+
label_groups=exclusive_groups + single_label_groups,
272273
num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
273274
num_multilabel_classes=single_label_group_info["num_multilabel_classes"],
274275
head_idx_to_logits_range=exclusive_group_info["head_idx_to_logits_range"],
275276
num_single_label_classes=exclusive_group_info["num_single_label_classes"],
276277
class_to_group_idx=merged_class_to_idx,
277-
all_groups=all_groups,
278+
all_groups=exclusive_groups + single_label_groups,
278279
label_to_idx=label_to_idx,
279280
label_tree_edges=get_label_tree_edges(dm_label_categories.items),
280281
empty_multiclass_head_indices=[], # consider the label removing case

0 commit comments

Comments
 (0)