1313
1414from iai_core .entities .graph import MultiDiGraph
1515from iai_core .entities .label import Label
16- from iai_core .entities .scored_label import ScoredLabel
1716from iai_core .utils .uid_generator import generate_uid
1817
1918from geti_types import ID , PersistentEntity
@@ -73,15 +72,10 @@ def __init__(
7372 ):
7473 self .id_ = ID (ObjectId ()) if id is None else id
7574
76- self .labels = sorted (labels , key = natural_sort_label_id )
75+ self .labels = list (labels )
7776 self .name = name
7877 self .group_type = group_type
7978
80- @property
81- def minimum_label_id (self ) -> ID :
82- """Returns the minimum (oldest) label ID, which is the first label in self.labels since this list is sorted."""
83- return self .labels [0 ].id_
84-
8579 def remove_label (self , label : Label ) -> None :
8680 """Remove label from label group if it exists in the group.
8781
@@ -103,7 +97,7 @@ def __eq__(self, other: object):
10397 """Returns True if the LabelGroup is equal to the other object."""
10498 if not isinstance (other , LabelGroup ):
10599 return False
106- return self .id_ == other .id_ and ( set ( self . labels ) == set ( other . labels ) and self . group_type == other . group_type )
100+ return self .id_ == other .id_
107101
108102 def __repr__ (self ) -> str :
109103 """Returns the string representation of the LabelGroup."""
@@ -119,8 +113,6 @@ class LabelTree(MultiDiGraph):
119113 def __init__ (self ) -> None :
120114 super ().__init__ ()
121115
122- self .__topological_order_cache : list [Label ] | None = None
123-
124116 def add_edge (self , node1 : Label , node2 : Label , edge_value : Any = None ) -> None :
125117 """Add edge between two nodes in the tree.
126118
@@ -129,50 +121,24 @@ def add_edge(self, node1: Label, node2: Label, edge_value: Any = None) -> None:
129121 :param edge_value: The value of the new edge. Defaults to None.
130122 """
131123 super ().add_edge (node1 , node2 , edge_value )
132- self .clear_topological_cache ()
133124
134125 def add_node (self , node : Label ) -> None :
135126 """Add node to the tree."""
136127 super ().add_node (node )
137- self .clear_topological_cache ()
138128
139129 def add_edges (self , edges : Any ) -> None :
140130 """Add edges between Labels."""
141131 self ._graph .add_edges_from (edges )
142- self .clear_topological_cache ()
143132
144133 def remove_node (self , node : Label ) -> None :
145134 """Remove node from the tree."""
146135 super ().remove_node (node )
147- self .clear_topological_cache ()
148136
149137 @property
150138 def num_labels (self ) -> int :
151139 """Return the number of labels in the tree."""
152140 return self .num_nodes ()
153141
154- def clear_topological_cache (self ) -> None :
155- """Clear the internal cache of the list of labels sorted in topological order.
156-
157- This function should be called if the topology of the graph has changed to
158- prevent the cache from being stale.
159- Note that it is automatically called when modifying the topology through the
160- methods provided by this class.
161- """
162- self .__topological_order_cache = None
163-
164- def get_labels_in_topological_order (self ) -> list [Label ]:
165- """Return a list of the labels in this graph sorted in topological order.
166-
167- To avoid performance issues, the output of this function is cached.
168- """
169- if self .__topological_order_cache is None :
170- # TODO: It seems that we are storing the edges the wrong way around.
171- # To work around this issue, we have to reverse the sorted list.
172- self .__topological_order_cache = list (reversed (list (self .topological_sort ())))
173-
174- return self .__topological_order_cache
175-
176142 @property
177143 def type (self ) -> str :
178144 """Returns the type of the LabelTree."""
@@ -181,7 +147,6 @@ def type(self) -> str:
181147 def add_child (self , parent : Label , child : Label ) -> None :
182148 """Add a `child` Label to `parent`."""
183149 self .add_edge (child , parent )
184- self .clear_topological_cache ()
185150
186151 def get_parent (self , label : Label ) -> Label | None :
187152 """Returns the parent of `label`"""
@@ -322,12 +287,12 @@ def get_labels(self, include_empty: bool) -> list[Label]:
322287 :param include_empty: flag determining whether to include empty labels
323288 :return: list of labels in the label schema
324289 """
325- labels = []
326- for group in self . _groups :
327- for label in group . labels :
328- if ( include_empty or not label . is_empty ) and label . id_ not in self . deleted_label_ids :
329- labels . append ( label )
330- return sorted ( labels , key = lambda x : x . id_ )
290+ return [
291+ label
292+ for group in self . _groups
293+ for label in group . labels
294+ if ( include_empty or not label . is_empty ) and label . id_ not in self . deleted_label_ids
295+ ]
331296
332297 def get_label_map (self ) -> dict [ID , Label ]:
333298 """
@@ -345,7 +310,7 @@ def get_empty_labels(self) -> tuple[Label, ...]:
345310
346311 :return: tuple of empty labels in the label schema
347312 """
348- return tuple (sorted ([ label for label in self .get_labels (include_empty = True ) if label .is_empty ]) )
313+ return tuple (label for label in self .get_labels (include_empty = True ) if label .is_empty )
349314
350315 def get_label_ids (self , include_empty : bool ) -> list [ID ]:
351316 """
@@ -364,9 +329,7 @@ def get_all_labels(self) -> list[Label]:
364329
365330 :return: list of labels in the label schema
366331 """
367- labels = [label for group in self ._groups for label in group .labels ]
368-
369- return sorted (labels , key = lambda x : x .id_ )
332+ return [label for group in self ._groups for label in group .labels ]
370333
371334 def get_groups (self , include_empty : bool = False ) -> list [LabelGroup ]:
372335 """
@@ -594,14 +557,7 @@ def __repr__(self) -> str:
594557 def __eq__ (self , other : object ) -> bool :
595558 if not isinstance (other , LabelSchema ):
596559 return False
597- return (
598- self .id_ == other .id_
599- and self .project_id == other .project_id
600- and self .previous_schema_revision_id == other .previous_schema_revision_id
601- and self .label_tree == other .label_tree
602- and self .get_groups (include_empty = True ) == other .get_groups (include_empty = True )
603- and self .deleted_label_ids == other .deleted_label_ids
604- )
560+ return self .id_ == other .id_
605561
606562
607563class NullLabelSchema (LabelSchema ):
@@ -676,7 +632,7 @@ def from_parent(
676632 label_groups = []
677633
678634 for parent_group in parent_schema .get_groups (include_empty = True ):
679- group_labels = list ( set ( parent_group .labels ). intersection ( set_of_labels ))
635+ group_labels = [ label for label in parent_group .labels if label in set_of_labels ]
680636 if len (group_labels ) > 0 :
681637 label_groups .append (
682638 LabelGroup (
@@ -737,32 +693,4 @@ def __repr__(self) -> str:
737693 def __eq__ (self , other : object ) -> bool :
738694 if not isinstance (other , LabelSchemaView ):
739695 return False
740- return (
741- self .parent_schema == other .parent_schema
742- and self .task_node_id == other .task_node_id
743- and self .previous_schema_revision_id == other .previous_schema_revision_id
744- and self .label_tree == other .label_tree
745- and self .get_groups (include_empty = True ) == other .get_groups (include_empty = True )
746- )
747-
748-
749- def natural_sort_label_id (target : ID | Label | ScoredLabel ) -> list [int | str ]:
750- """Generates a natural sort key for a Label object based on its ID.
751-
752- Example:
753- origin_sorted_labels = sorted(labels, key=lambda x: x.id_)
754- natural_sorted_labels = sorted(labels, key=lambda x: x.natural_sort_label_id)
755-
756- print(origin_sorted_labels) # Output: [Label(0), Label(1), Label(10), ... Label(2)]
757- print(natural_sorted_labels) # Output: [Label(0), Label(1), Label(2), ... Label(10)]
758-
759- :param target (Union[ID, Label]): The ID or Label or ScoredLabel object to be sorted.
760- :returns: List[Union[int, str]]: A list of integers representing the numeric substrings in the ID
761- in the order they appear.
762- """
763-
764- if isinstance (target , Label | ScoredLabel ):
765- target = target .id_
766- if isinstance (target , str ) and target .isdecimal ():
767- return ["" , int (target )] # "" is added for the case where id of some lables is None
768- return [target ]
696+ return self .id_ == other .id_
0 commit comments