Skip to content

Commit 04948e7

Browse files
Multi-label and hierarchical-label classification support (#1925)
This pull request introduces comprehensive support for hierarchical label categories in the experimental categories module. It adds new data structures to represent hierarchical relationships between labels, updates legacy conversion logic to handle hierarchical and multi-label classification, and improves validation and compatibility. The changes also make several dataclasses mutable and update type imports for broader compatibility. **Hierarchical label support and data structures:** * Introduced new classes: `HierarchicalLabelCategory`, `LabelGroup`, and `HierarchicalLabelCategories` in `categories.py` to represent hierarchical label structures, label groups, and provide methods for hierarchy traversal and validation. These classes support parent-child relationships, groupings, and compatibility with existing interfaces. * Added extensive validation in `HierarchicalLabelCategories.__post_init__` to ensure label uniqueness, group-label consistency, and valid parent references. * Provided utility methods for hierarchy navigation (e.g., `find`, `get_children`, `get_parent`, `get_hierarchy_level`) and compatibility with legacy APIs. **Legacy dataset conversion and analysis improvements:** * Enhanced `analyze_legacy_dataset` and `convert_from_legacy` in `legacy.py` to detect hierarchical and multi-label projects, convert legacy label/group structures to the new classes, and ensure hierarchical labels are handled as lists. [[1]](diffhunk://#diff-aa35f06eaa7d35ff5ffa7077e181ed0b773549c22d42a92e78e076947f9b88f5L776-R817) [[2]](diffhunk://#diff-aa35f06eaa7d35ff5ffa7077e181ed0b773549c22d42a92e78e076947f9b88f5L803-R900) [[3]](diffhunk://#diff-aa35f06eaa7d35ff5ffa7077e181ed0b773549c22d42a92e78e076947f9b88f5L862-R949) * Added helper functions `_attributes_to_dict` and `_has_derived_labels` for legacy attribute parsing and hierarchy detection. **API and compatibility changes:** * Added the `RESTRICTED` value to the `GroupType` enum to support empty label groups. * Broadened type imports in `categories.py` for improved type hinting and compatibility. **Other codebase updates:** * Updated test imports in `test_categories.py` to include new hierarchical classes. * Simplified Polars conversion logic in `fields.py` for label serialization. These changes lay the groundwork for robust hierarchical and multi-label classification support in the experimental Datumaro API. ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added tests to cover my changes or documented any manual tests. - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly --------- Signed-off-by: Albert van Houten <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 47e3576 commit 04948e7

File tree

5 files changed

+578
-21
lines changed

5 files changed

+578
-21
lines changed

src/datumaro/experimental/categories.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from collections import defaultdict
1415
from dataclasses import dataclass, field
1516
from enum import IntEnum
1617
from functools import cache
@@ -33,6 +34,7 @@ class GroupType(IntEnum):
3334

3435
EXCLUSIVE = 0 # Only one label from the group can be assigned
3536
INCLUSIVE = 1 # Multiple labels from the group can be assigned
37+
RESTRICTED = 2 # For empty labels
3638

3739
def to_str(self) -> str:
3840
return self.name.lower()
@@ -58,7 +60,10 @@ class Categories:
5860

5961
@dataclass(frozen=True)
6062
class LabelCategories(Categories):
61-
"""Represents a group of labels with a specific group type and semantics."""
63+
"""
64+
Represents a group of labels with a specific group type and semantics.
65+
Use this for simple, non-hierarchical tasks.
66+
"""
6267

6368
labels: Tuple[str, ...] = field(default_factory=tuple)
6469
group_type: GroupType = GroupType.EXCLUSIVE
@@ -131,6 +136,189 @@ def __hash__(self):
131136
return hash((self.labels, self.group_type, frozenset(self.label_semantics.items())))
132137

133138

139+
@dataclass(frozen=True)
140+
class HierarchicalLabelCategory:
141+
"""Represents a single label category with hierarchical support."""
142+
143+
name: str
144+
parent: str = field(default="")
145+
label_semantics: dict = field(default_factory=dict)
146+
147+
def __post_init__(self):
148+
"""Validate that name is not empty."""
149+
if not self.name or not isinstance(self.name, str):
150+
raise ValueError("Label name cannot be empty and must be a string")
151+
152+
def __hash__(self):
153+
return hash((self.name, self.parent, frozenset(self.label_semantics.items())))
154+
155+
156+
@dataclass(frozen=True)
157+
class LabelGroup:
158+
"""Represents a group of labels with a specific group type."""
159+
160+
name: str
161+
labels: Tuple[str, ...] = field(default_factory=tuple)
162+
group_type: GroupType = GroupType.EXCLUSIVE
163+
164+
def __post_init__(self):
165+
"""Validate that name is not empty and labels is a tuple."""
166+
if not self.name or not isinstance(self.name, str):
167+
raise ValueError("Label group name cannot be empty and must be a string")
168+
169+
170+
@dataclass(frozen=True)
171+
class HierarchicalLabelCategories(Categories):
172+
"""
173+
Represents hierarchical label categories with groups and parent-child relationships.
174+
Use this for complex hierarchical classification tasks.
175+
"""
176+
177+
items: Tuple[HierarchicalLabelCategory, ...] = field(default_factory=tuple)
178+
label_groups: Tuple[LabelGroup, ...] = field(default_factory=tuple)
179+
label_semantics: dict = field(default_factory=dict)
180+
181+
def __post_init__(self):
182+
if not isinstance(self.items, tuple):
183+
raise TypeError("items must be a tuple of HierarchicalLabelCategory")
184+
if not isinstance(self.label_groups, tuple):
185+
raise TypeError("label_groups must be a tuple of LabelGroup")
186+
187+
# Validate no duplicate names
188+
seen_names = set()
189+
for item in self.items:
190+
if item.name in seen_names:
191+
raise ValueError(f"Duplicate label name: {item.name}")
192+
seen_names.add(item.name)
193+
194+
# Also check for name label_semantics
195+
# TODO: Remove this check after migrating completely to new system
196+
if "name" in item.label_semantics:
197+
name = item.label_semantics["name"]
198+
if name not in seen_names:
199+
seen_names.add(name)
200+
201+
# Validate that all parents exist
202+
for item in self.items:
203+
if item.parent and item.parent not in seen_names:
204+
raise ValueError(f"Parent '{item.parent}' not found for label '{item.name}'")
205+
206+
# Validate that all labels in groups exist
207+
for group in self.label_groups:
208+
for label_name in group.labels:
209+
if label_name not in seen_names:
210+
raise ValueError(
211+
f"Label '{label_name}' in group '{group.name}' not found in items"
212+
)
213+
214+
@property
215+
@cache
216+
def _index_map(self) -> Dict[str, int]:
217+
"""Cached mapping from label names to indices."""
218+
return {item.name: idx for idx, item in enumerate(self.items)}
219+
220+
@property
221+
@cache
222+
def _children_map(self) -> Dict[str, Tuple[str, ...]]:
223+
"""Cached mapping from parent names to child names."""
224+
children_map: defaultdict[str, List[str]] = defaultdict(list)
225+
for item in self.items:
226+
if item.parent:
227+
children_map[item.parent].append(item.name)
228+
return {parent: tuple(children) for parent, children in children_map.items()}
229+
230+
@property
231+
def labels(self) -> Tuple[str, ...]:
232+
"""Get all label names for compatibility."""
233+
return tuple(item.name for item in self.items)
234+
235+
def find(self, name: str) -> Tuple[Optional[int], Optional[HierarchicalLabelCategory]]:
236+
"""
237+
Find a label by name.
238+
239+
Args:
240+
name: The label name to find
241+
242+
Returns:
243+
A tuple of (index, category) or (None, None) if not found
244+
"""
245+
index = self._index_map.get(name)
246+
if index is not None:
247+
return index, self.items[index]
248+
return None, None
249+
250+
def get_children(self, parent_name: str) -> Tuple[str, ...]:
251+
"""
252+
Get all children of a parent label.
253+
254+
Args:
255+
parent_name: The name of the parent label
256+
257+
Returns:
258+
Tuple of child label names
259+
"""
260+
return self._children_map.get(parent_name, ())
261+
262+
def get_parent(self, label_name: str) -> Optional[str]:
263+
"""
264+
Get the parent of a label.
265+
266+
Args:
267+
label_name: The name of the label
268+
269+
Returns:
270+
Parent name or None if no parent
271+
"""
272+
index = self._index_map.get(label_name)
273+
if index is not None:
274+
return self.items[index].parent
275+
return None
276+
277+
def get_hierarchy_level(self, label_name: str) -> int:
278+
"""
279+
Get the hierarchy level of a label (0 for root, 1 for first level children, etc.)
280+
281+
Args:
282+
label_name: The name of the label
283+
284+
Returns:
285+
Hierarchy level
286+
"""
287+
level = 0
288+
current = label_name
289+
while True:
290+
parent = self.get_parent(current)
291+
if not parent:
292+
break
293+
level += 1
294+
current = parent
295+
return level
296+
297+
def __getitem__(self, idx: int) -> HierarchicalLabelCategory:
298+
"""Get category by index."""
299+
return self.items[idx]
300+
301+
def __contains__(self, value: Union[int, str]) -> bool:
302+
"""Check if a label exists by name or index."""
303+
if isinstance(value, str):
304+
return value in self._index_map
305+
else:
306+
return 0 <= value < len(self.items)
307+
308+
def __len__(self) -> int:
309+
"""Get the number of labels."""
310+
return len(self.items)
311+
312+
def __iter__(self) -> Iterator[HierarchicalLabelCategory]:
313+
"""Iterate over label categories."""
314+
return iter(self.items)
315+
316+
def __hash__(self):
317+
# Hash label_groups via value-based representation to avoid relying on their own hash implementation.
318+
lg_repr = tuple((lg.name, tuple(lg.labels), lg.group_type) for lg in self.label_groups)
319+
return hash((self.items, lg_repr, frozenset(self.label_semantics.items())))
320+
321+
134322
class RgbColor(NamedTuple):
135323
"""RGB color representation with named fields."""
136324

src/datumaro/experimental/fields.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -532,15 +532,7 @@ def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
532532

533533
def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
534534
"""Convert label(s) to Polars format for single or multi-label cases."""
535-
pl_type = self._pl_type
536-
537-
if value is None:
538-
return {name: pl.Series(name, [None], dtype=pl_type)}
539-
540-
if self.multi_label:
541-
return {name: pl.Series(name, [to_numpy(value)], dtype=pl.List(self.dtype))}
542-
543-
return {name: pl.Series(name, [value], dtype=pl_type)}
535+
return {name: pl.Series(name, [value], dtype=self._pl_type)}
544536

545537
def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type[T]) -> T:
546538
"""Reconstruct label(s) from Polars data."""

0 commit comments

Comments
 (0)