Skip to content

Commit 96fbcfa

Browse files
authored
ITEP-67565: Allow users to sort labels in any order they want (#239)
1 parent 26ac753 commit 96fbcfa

File tree

9 files changed

+74
-262
lines changed

9 files changed

+74
-262
lines changed

interactive_ai/libs/iai_core_py/iai_core/entities/graph.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,3 @@ class MultiDiGraph(Graph):
129129

130130
def __init__(self) -> None:
131131
super().__init__(directed=True)
132-
133-
def topological_sort(self):
134-
"""Returns a generator of nodes in topologically sorted order."""
135-
return nx.topological_sort(self._graph)

interactive_ai/libs/iai_core_py/iai_core/entities/label.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -145,14 +145,7 @@ def __repr__(self):
145145
def __eq__(self, other: object) -> bool:
146146
"""Returns True if the two labels are equal."""
147147
if isinstance(other, Label):
148-
return (
149-
self.id_ == other.id_
150-
and self.name == other.name
151-
and self.color == other.color
152-
and self.hotkey == other.hotkey
153-
and self.domain == other.domain
154-
and self.is_anomalous == other.is_anomalous
155-
)
148+
return self.id_ == other.id_
156149
return False
157150

158151
def __lt__(self, other: object):

interactive_ai/libs/iai_core_py/iai_core/entities/label_schema.py

Lines changed: 13 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
from iai_core.entities.graph import MultiDiGraph
1515
from iai_core.entities.label import Label
16-
from iai_core.entities.scored_label import ScoredLabel
1716
from iai_core.utils.uid_generator import generate_uid
1817

1918
from geti_types import ID, PersistentEntity
@@ -73,15 +72,10 @@ def __init__(
7372
):
7473
self.id_ = ID(ObjectId()) if id is None else id
7574

76-
self.labels = sorted(labels, key=natural_sort_label_id)
75+
self.labels = list(labels)
7776
self.name = name
7877
self.group_type = group_type
7978

80-
@property
81-
def minimum_label_id(self) -> ID:
82-
"""Returns the minimum (oldest) label ID, which is the first label in self.labels since this list is sorted."""
83-
return self.labels[0].id_
84-
8579
def remove_label(self, label: Label) -> None:
8680
"""Remove label from label group if it exists in the group.
8781
@@ -103,7 +97,7 @@ def __eq__(self, other: object):
10397
"""Returns True if the LabelGroup is equal to the other object."""
10498
if not isinstance(other, LabelGroup):
10599
return False
106-
return self.id_ == other.id_ and (set(self.labels) == set(other.labels) and self.group_type == other.group_type)
100+
return self.id_ == other.id_
107101

108102
def __repr__(self) -> str:
109103
"""Returns the string representation of the LabelGroup."""
@@ -119,8 +113,6 @@ class LabelTree(MultiDiGraph):
119113
def __init__(self) -> None:
120114
super().__init__()
121115

122-
self.__topological_order_cache: list[Label] | None = None
123-
124116
def add_edge(self, node1: Label, node2: Label, edge_value: Any = None) -> None:
125117
"""Add edge between two nodes in the tree.
126118
@@ -129,50 +121,24 @@ def add_edge(self, node1: Label, node2: Label, edge_value: Any = None) -> None:
129121
:param edge_value: The value of the new edge. Defaults to None.
130122
"""
131123
super().add_edge(node1, node2, edge_value)
132-
self.clear_topological_cache()
133124

134125
def add_node(self, node: Label) -> None:
135126
"""Add node to the tree."""
136127
super().add_node(node)
137-
self.clear_topological_cache()
138128

139129
def add_edges(self, edges: Any) -> None:
140130
"""Add edges between Labels."""
141131
self._graph.add_edges_from(edges)
142-
self.clear_topological_cache()
143132

144133
def remove_node(self, node: Label) -> None:
145134
"""Remove node from the tree."""
146135
super().remove_node(node)
147-
self.clear_topological_cache()
148136

149137
@property
150138
def num_labels(self) -> int:
151139
"""Return the number of labels in the tree."""
152140
return self.num_nodes()
153141

154-
def clear_topological_cache(self) -> None:
155-
"""Clear the internal cache of the list of labels sorted in topological order.
156-
157-
This function should be called if the topology of the graph has changed to
158-
prevent the cache from being stale.
159-
Note that it is automatically called when modifying the topology through the
160-
methods provided by this class.
161-
"""
162-
self.__topological_order_cache = None
163-
164-
def get_labels_in_topological_order(self) -> list[Label]:
165-
"""Return a list of the labels in this graph sorted in topological order.
166-
167-
To avoid performance issues, the output of this function is cached.
168-
"""
169-
if self.__topological_order_cache is None:
170-
# TODO: It seems that we are storing the edges the wrong way around.
171-
# To work around this issue, we have to reverse the sorted list.
172-
self.__topological_order_cache = list(reversed(list(self.topological_sort())))
173-
174-
return self.__topological_order_cache
175-
176142
@property
177143
def type(self) -> str:
178144
"""Returns the type of the LabelTree."""
@@ -181,7 +147,6 @@ def type(self) -> str:
181147
def add_child(self, parent: Label, child: Label) -> None:
182148
"""Add a `child` Label to `parent`."""
183149
self.add_edge(child, parent)
184-
self.clear_topological_cache()
185150

186151
def get_parent(self, label: Label) -> Label | None:
187152
"""Returns the parent of `label`"""
@@ -322,12 +287,12 @@ def get_labels(self, include_empty: bool) -> list[Label]:
322287
:param include_empty: flag determining whether to include empty labels
323288
:return: list of labels in the label schema
324289
"""
325-
labels = []
326-
for group in self._groups:
327-
for label in group.labels:
328-
if (include_empty or not label.is_empty) and label.id_ not in self.deleted_label_ids:
329-
labels.append(label)
330-
return sorted(labels, key=lambda x: x.id_)
290+
return [
291+
label
292+
for group in self._groups
293+
for label in group.labels
294+
if (include_empty or not label.is_empty) and label.id_ not in self.deleted_label_ids
295+
]
331296

332297
def get_label_map(self) -> dict[ID, Label]:
333298
"""
@@ -345,7 +310,7 @@ def get_empty_labels(self) -> tuple[Label, ...]:
345310
346311
:return: tuple of empty labels in the label schema
347312
"""
348-
return tuple(sorted([label for label in self.get_labels(include_empty=True) if label.is_empty]))
313+
return tuple(label for label in self.get_labels(include_empty=True) if label.is_empty)
349314

350315
def get_label_ids(self, include_empty: bool) -> list[ID]:
351316
"""
@@ -364,9 +329,7 @@ def get_all_labels(self) -> list[Label]:
364329
365330
:return: list of labels in the label schema
366331
"""
367-
labels = [label for group in self._groups for label in group.labels]
368-
369-
return sorted(labels, key=lambda x: x.id_)
332+
return [label for group in self._groups for label in group.labels]
370333

371334
def get_groups(self, include_empty: bool = False) -> list[LabelGroup]:
372335
"""
@@ -594,14 +557,7 @@ def __repr__(self) -> str:
594557
def __eq__(self, other: object) -> bool:
595558
if not isinstance(other, LabelSchema):
596559
return False
597-
return (
598-
self.id_ == other.id_
599-
and self.project_id == other.project_id
600-
and self.previous_schema_revision_id == other.previous_schema_revision_id
601-
and self.label_tree == other.label_tree
602-
and self.get_groups(include_empty=True) == other.get_groups(include_empty=True)
603-
and self.deleted_label_ids == other.deleted_label_ids
604-
)
560+
return self.id_ == other.id_
605561

606562

607563
class NullLabelSchema(LabelSchema):
@@ -676,7 +632,7 @@ def from_parent(
676632
label_groups = []
677633

678634
for parent_group in parent_schema.get_groups(include_empty=True):
679-
group_labels = list(set(parent_group.labels).intersection(set_of_labels))
635+
group_labels = [label for label in parent_group.labels if label in set_of_labels]
680636
if len(group_labels) > 0:
681637
label_groups.append(
682638
LabelGroup(
@@ -737,32 +693,4 @@ def __repr__(self) -> str:
737693
def __eq__(self, other: object) -> bool:
738694
if not isinstance(other, LabelSchemaView):
739695
return False
740-
return (
741-
self.parent_schema == other.parent_schema
742-
and self.task_node_id == other.task_node_id
743-
and self.previous_schema_revision_id == other.previous_schema_revision_id
744-
and self.label_tree == other.label_tree
745-
and self.get_groups(include_empty=True) == other.get_groups(include_empty=True)
746-
)
747-
748-
749-
def natural_sort_label_id(target: ID | Label | ScoredLabel) -> list[int | str]:
750-
"""Generates a natural sort key for a Label object based on its ID.
751-
752-
Example:
753-
origin_sorted_labels = sorted(labels, key=lambda x: x.id_)
754-
natural_sorted_labels = sorted(labels, key=lambda x: x.natural_sort_label_id)
755-
756-
print(origin_sorted_labels) # Output: [Label(0), Label(1), Label(10), ... Label(2)]
757-
print(natural_sorted_labels) # Output: [Label(0), Label(1), Label(2), ... Label(10)]
758-
759-
:param target (Union[ID, Label]): The ID or Label or ScoredLabel object to be sorted.
760-
:returns: List[Union[int, str]]: A list of integers representing the numeric substrings in the ID
761-
in the order they appear.
762-
"""
763-
764-
if isinstance(target, Label | ScoredLabel):
765-
target = target.id_
766-
if isinstance(target, str) and target.isdecimal():
767-
return ["", int(target)] # "" is added for the case where id of some lables is None
768-
return [target]
696+
return self.id_ == other.id_

interactive_ai/libs/iai_core_py/iai_core/repos/mappers/mongodb_mappers/label_mapper.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -91,13 +91,13 @@ def backward(instance: dict, project_identifier: ProjectIdentifier) -> LabelGrou
9191
from iai_core.repos import LabelRepo
9292

9393
label_repo = LabelRepo(project_identifier)
94+
label_ids = [IDToMongo.backward(label_id) for label_id in instance["label_ids"]]
95+
label_map = label_repo.get_by_ids(label_ids)
9496
return LabelGroup(
9597
id=IDToMongo.backward(instance["_id"]),
9698
name=instance["name"],
9799
group_type=LabelGroupType[instance["relation_type"]],
98-
labels=list(
99-
label_repo.get_by_ids([IDToMongo.backward(label_id) for label_id in instance["label_ids"]]).values()
100-
),
100+
labels=[label_map[label_id] for label_id in label_ids],
101101
)
102102

103103

interactive_ai/libs/iai_core_py/iai_core/utils/project_builder.py

Lines changed: 47 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -985,38 +985,65 @@ def edit_existing_project( # noqa: C901, PLR0912, PLR0915
985985
if label_parent:
986986
child_to_parent_id[label] = label_parent
987987
group_name_by_label[label.id_] = label_group
988-
labels_by_task[task_node.id_] = updated_labels
988+
989+
# Sort labels according to ordering in request
990+
ordered_label_names = parser.get_custom_labels_names_by_task(task_name=task_name)
991+
ordered_updated_labels = []
992+
for label_name in ordered_label_names:
993+
for label in updated_labels:
994+
if label_name == label.name:
995+
ordered_updated_labels.append(label)
996+
997+
# Add empty label, if it exists in the old labels, but not in the updated ones
998+
if not any(label.is_empty for label in ordered_updated_labels):
999+
empty_label = next((label for label in old_labels if label.is_empty), None)
1000+
if empty_label is not None:
1001+
ordered_updated_labels.append(empty_label)
1002+
empty_group = next(group for group in old_groups if group.group_type == LabelGroupType.EMPTY_LABEL)
1003+
group_name_by_label[empty_label.id_] = empty_group.name
1004+
1005+
labels_by_task[task_node.id_] = ordered_updated_labels
9891006

9901007
if is_keypoint_detection_enabled and task_node.task_properties.task_type == TaskType.KEYPOINT_DETECTION:
9911008
keypoint_structure_data = parser.get_keypoint_structure_data(task_name=task_name)
9921009
keypoint_structure = cls._build_keypoint_structure(
9931010
keypoint_structure_data=keypoint_structure_data,
994-
labels=updated_labels,
1011+
labels=ordered_updated_labels,
9951012
)
9961013
project.keypoint_structure = keypoint_structure
9971014

998-
if labels_structure_changed:
999-
# Labels structure changed, so the label_schema and task_to_label_schema_view are recomputed from scratch.
1000-
(
1001-
label_schema,
1002-
task_to_label_schema_view,
1003-
) = ProjectBuilder._build_label_schema( # type: ignore[assignment]
1004-
project=project,
1005-
labels_by_task=labels_by_task,
1006-
group_name_by_label_id=group_name_by_label,
1007-
child_to_parent_id=child_to_parent_id, # type: ignore
1008-
previous_schema_revision_id=label_schema.id_,
1009-
previous_task_id_to_schema_revision_id=task_id_to_schema_revision_id,
1010-
)
1011-
label_schema_repo.save(instance=label_schema)
1012-
for _, task_label_schema in task_to_label_schema_view.items():
1013-
label_schema_repo.save(instance=task_label_schema)
1015+
(
1016+
new_label_schema,
1017+
new_task_to_label_schema_view,
1018+
) = ProjectBuilder._build_label_schema( # type: ignore[assignment]
1019+
project=project,
1020+
labels_by_task=labels_by_task,
1021+
group_name_by_label_id=group_name_by_label,
1022+
child_to_parent_id=child_to_parent_id, # type: ignore
1023+
previous_schema_revision_id=label_schema.id_,
1024+
previous_task_id_to_schema_revision_id=task_id_to_schema_revision_id,
1025+
)
1026+
1027+
# Reset label schema (view) ids if no structural change was made to the label topology
1028+
# Note that we always have to regenerate and save the schemas to ensure the order of the labels is updated
1029+
if not labels_structure_changed:
1030+
new_label_schema.id_ = label_schema.id_
1031+
new_label_schema.previous_schema_revision_id = label_schema.previous_schema_revision_id
1032+
for task_title, new_task_label_schema_view in new_task_to_label_schema_view.items():
1033+
new_task_label_schema_view.id_ = task_to_label_schema_view[task_title].id_
1034+
new_task_label_schema_view.previous_schema_revision_id = task_to_label_schema_view[
1035+
task_title
1036+
].previous_schema_revision_id
1037+
1038+
label_schema_repo.save(instance=new_label_schema)
1039+
for new_task_label_schema in new_task_to_label_schema_view.values():
1040+
label_schema_repo.save(instance=new_task_label_schema)
10141041

10151042
project_repo.save(instance=project)
10161043
return (
10171044
project,
1018-
label_schema,
1019-
task_to_label_schema_view,
1045+
new_label_schema,
1046+
new_task_to_label_schema_view,
10201047
tuple(labels_to_revisit),
10211048
modified_scene_ids_by_storage,
10221049
labels_structure_changed,

interactive_ai/libs/iai_core_py/tests/entities/test_graph.py

Lines changed: 0 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -851,19 +851,3 @@ def test_multi_di_graph(self):
851851
},
852852
]
853853
)
854-
855-
def test_multi_di_graph_topological_sort(self):
856-
"""
857-
<b>Description:</b>
858-
Check topological_sort method of MultiDiGraph class object
859-
860-
<b>Input data:</b>
861-
MultiDiGraph objects with specified "edges" parameter
862-
863-
<b>Expected results:</b>
864-
Test passes if topological_sort method returns generator object with expected values
865-
"""
866-
multi_di_graph = self.multi_di_graph()
867-
topological_sort = multi_di_graph.topological_sort()
868-
for expected_value in [(1, 1), (1, 2), (3, 1), (2, 1)]:
869-
assert next(topological_sort) == expected_value

0 commit comments

Comments
 (0)