Skip to content

Commit 85bfe9f

Browse files
authored
Align label order between Geti and OTX (#2369)
* align label order * align with pre-commit * update CHANGELOG.md * deal with edge case * update type hint
1 parent 0c4be3c commit 85bfe9f

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ All notable changes to this project will be documented in this file.
3333

3434
- Fix the bug that auto adapt batch size is unavailable with IterBasedRunner (<https://github.com/openvinotoolkit/training_extensions/pull/2182>)
3535
- Fix the bug that learning rate isn't scaled when multi-GPU trianing is enabled(<https://github.com/openvinotoolkit/training_extensions/pull/2254>)
36+
- Fix the bug that label order is misaligned when model is deployed from Geti (<https://github.com/openvinotoolkit/training_extensions/pull/2369>)
3637

3738
### Known issues
3839

src/otx/api/entities/label_schema.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24-
def natural_sort_label_id(target: Union[ID, LabelEntity, ScoredLabel]) -> List:
24+
def natural_sort_label_id(target: Union[ID, LabelEntity, ScoredLabel]) -> List[Union[int, str]]:
2525
"""Generates a natural sort key for a LabelEntity object based on its ID.
2626
2727
Args:
2828
target (Union[ID, LabelEntity]): The ID or LabelEntity or ScoredLabel object to be sorted.
2929
3030
Returns:
31-
List[int]: A list of integers representing the numeric substrings in the ID
31+
List[Union[int, str]]: A list of integers representing the numeric substrings in the ID
3232
in the order they appear.
3333
3434
Example:
@@ -41,9 +41,9 @@ def natural_sort_label_id(target: Union[ID, LabelEntity, ScoredLabel]) -> List:
4141

4242
if isinstance(target, (LabelEntity, ScoredLabel)):
4343
target = target.id_
44-
if isinstance(target, int):
45-
return [target]
46-
return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", target)]
44+
if isinstance(target, str) and target.isdecimal():
45+
return ["", int(target)] # "" is added for the case where id of some lables is None
46+
return [target]
4747

4848

4949
class LabelGroupExistsException(ValueError):

tests/unit/api/entities/test_label_schema.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66

77
import pytest
8-
from networkx.classes.reportviews import EdgeDataView, NodeView, OutMultiEdgeDataView
8+
from networkx.classes.reportviews import NodeView, OutMultiEdgeDataView
99

1010
from otx.api.entities.color import Color
1111
from otx.api.entities.id import ID
@@ -19,11 +19,34 @@
1919
LabelSchemaEntity,
2020
LabelTree,
2121
ScoredLabel,
22+
natural_sort_label_id,
2223
)
2324
from tests.unit.api.constants.components import OtxSdkComponent
2425
from tests.unit.api.constants.requirements import Requirements
2526

2627

28+
def get_label_entity(id_val: str):
29+
return LabelEntity(name=id_val, domain=Domain.DETECTION, id=ID(id_val))
30+
31+
32+
def get_scored_label(id_val: str):
33+
return ScoredLabel(label=get_label_entity(id_val))
34+
35+
36+
@pytest.mark.priority_medium
37+
@pytest.mark.unit
38+
@pytest.mark.reqids(Requirements.REQ_1)
39+
@pytest.mark.parametrize("id_val", ["3", "fake1name2"])
40+
@pytest.mark.parametrize("target_class", [ID, get_label_entity, get_scored_label])
41+
def test_natural_sort_label_id(id_val: str, target_class):
42+
target = target_class(id_val)
43+
44+
if id_val.isdecimal():
45+
assert natural_sort_label_id(target) == ["", int(id_val)]
46+
else:
47+
assert natural_sort_label_id(target) == [id_val]
48+
49+
2750
@pytest.mark.components(OtxSdkComponent.OTX_API)
2851
class TestLabelSchema:
2952
@pytest.mark.priority_medium

tests/unit/api/usecases/exportable_code/test_prediction_to_annotation_converter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,6 @@ def test_classification_to_annotation_init(self):
750750
)
751751
label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group])
752752
converter = ClassificationToAnnotationConverter(label_schema=label_schema)
753-
assert converter.labels == non_empty_labels + other_non_empty_labels
754753
assert not converter.empty_label
755754
assert converter.label_schema == label_schema
756755
assert converter.hierarchical

0 commit comments

Comments
 (0)