Skip to content

Commit e7e3b42

Browse files
AlbertvanHoutenCopilotCopilot
authored
Implement dataset export/import (#1946)
This pull request adds support to import and export datasets. An exported dataset will include all the images, where references to callables for images/masks are saved as PNG and image paths will be copied. The exported dataset also includes metadata, such as the categories and schema of the dataset and a version number in a JSON file. The polars dataframe is stored in parquet. Additionally, there is an option to export everything in a zip. This zip can also be imported as is without the need to manually extract it first. **Serialization and Deserialization Enhancements** * Added `to_dict` and `from_dict` methods to `Categories`, `LabelCategories`, `HierarchicalLabelCategories`, and `MaskCategories`, enabling polymorphic serialization and reconstruction of category objects. [[1]](diffhunk://#diff-f0858b860d68536c04476766ef0512d873b7f7cc20dde377a42ec13a836d9622L58-R101) [[2]](diffhunk://#diff-f0858b860d68536c04476766ef0512d873b7f7cc20dde377a42ec13a836d9622R181-R223) [[3]](diffhunk://#diff-f0858b860d68536c04476766ef0512d873b7f7cc20dde377a42ec13a836d9622R407-R496) [[4]](diffhunk://#diff-f0858b860d68536c04476766ef0512d873b7f7cc20dde377a42ec13a836d9622R565-R601) * Implemented `to_dict` and `from_dict` methods for the `Field` base class, using dataclass introspection to automatically serialize and reconstruct field attributes, including special handling for semantic and dtype fields. * Added serialization and deserialization logic to the `Schema` class, including storing type information and reconstructing attribute types via module introspection. **Dataset Import/Export API** * Exposed `export_dataset` and `import_dataset` functions at the module level, and added `export` and `from_file` methods to the `Dataset` class for saving/loading datasets in a structured format. 
[[1]](diffhunk://#diff-6c66a059cb4075cd7afbc29b4f34236f8101a1a47cdcc429c7a2d25868cc74b5R8) [[2]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fR500-R554) These changes lay the groundwork for robust dataset interchange and future compatibility across different category and schema types. <!-- Contributing guide: https://github.com/open-edge-platform/datumaro/blob/develop/contributing.md --> <!-- Please add a summary of changes. You may use Copilot to auto-generate the PR description but please consider including any other relevant facts which Copilot may be unaware of (such as design choices and testing procedure). Add references to the relevant issues and pull requests if any like so: Resolves #111 and #222. Depends on #1000 (for series of dependent commits). --> ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added tests to cover my changes or documented any manual tests. - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly --------- Signed-off-by: Albert van Houten <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 411dcb5 commit e7e3b42

File tree

7 files changed

+2458
-5
lines changed

7 files changed

+2458
-5
lines changed

src/datumaro/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from . import converters # Import converters to register them
66
from .converter_registry import ConverterRegistry, converter, find_conversion_path
77
from .dataset import Dataset, Sample
8+
from .export_import import export_dataset, import_dataset
89
from .fields import (
910
BBoxField,
1011
ImageCallableField,

src/datumaro/experimental/categories.py

Lines changed: 215 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from dataclasses import dataclass, field
1616
from enum import IntEnum
1717
from functools import cache
18-
from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
18+
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
1919

2020

2121
class LabelSemantic(IntEnum):
@@ -55,7 +55,50 @@ class Categories:
5555
label attributes etc.
5656
"""
5757

58-
pass
58+
def to_dict(self) -> Dict[str, Any]:
    """
    Produce the minimal JSON-compatible description of this instance.

    The base implementation records nothing except the name of the concrete
    class under the "type" key. Subclasses that carry real data are expected
    to override this method.

    Returns:
        Dictionary holding a single "type" entry
    """
    type_name = self.__class__.__name__
    return {"type": type_name}
69+
70+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Categories":
    """
    Deserialize a Categories instance from a JSON dictionary.

    The "type" entry of ``data`` selects the class that rebuilds the object;
    the actual reconstruction is delegated to that class's own ``from_dict``.

    Args:
        data: Dictionary containing serialized Categories data

    Returns:
        Reconstructed Categories instance of the appropriate subclass

    Raises:
        ValueError: If the "type" entry is missing or empty, or if it names
            a categories type this method does not know about.
    """
    cat_type = data.get("type")
    if not cat_type:
        raise ValueError("Categories dictionary must have a 'type' field")

    # The base class serializes to just {"type": "Categories"}; accept that
    # payload so a plain Categories instance survives a round trip.
    if cat_type == "Categories":
        return Categories()

    # Built at call time: the subclasses are defined later in this module.
    subclass_map = {
        "LabelCategories": LabelCategories,
        "HierarchicalLabelCategories": HierarchicalLabelCategories,
        "MaskCategories": MaskCategories,
    }

    target = subclass_map.get(cat_type)
    if target is None:
        # Unknown types are rejected rather than silently downgraded to a
        # bare Categories object, which would drop their payload.
        raise ValueError(f"Unknown categories type: {cat_type}")

    return target.from_dict(data)
59102

60103

61104
@dataclass(frozen=True)
@@ -135,6 +178,49 @@ def __hash__(self):
135178
# Include label_semantics in the hash
136179
return hash((self.labels, self.group_type, frozenset(self.label_semantics.items())))
137180

181+
def to_dict(self) -> Dict[str, Any]:
    """
    Convert this LabelCategories instance into a JSON-compatible dictionary.

    Labels are emitted as a list, the group type by its enum member name,
    and every label-semantics key as a string (the member name for
    LabelSemantic keys, ``str(key)`` for anything else).

    Returns:
        Dictionary representation of this LabelCategories instance
    """
    payload: Dict[str, Any] = {
        "type": "LabelCategories",
        "labels": list(self.labels),
        "group_type": self.group_type.name,
    }

    semantics: Dict[str, Any] = {}
    for sem_key, sem_value in self.label_semantics.items():
        if isinstance(sem_key, LabelSemantic):
            semantics[sem_key.name] = sem_value
        else:
            semantics[str(sem_key)] = sem_value
    payload["label_semantics"] = semantics

    return payload
197+
198+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LabelCategories":
    """
    Rebuild a LabelCategories instance from its dictionary form.

    Label-semantics keys that match a LabelSemantic member name are turned
    back into that member; any other key is kept as the string it was
    stored as.

    Args:
        data: Dictionary containing serialized LabelCategories data

    Returns:
        Reconstructed LabelCategories instance
    """
    known_semantics = LabelSemantic.__members__
    restored = {
        known_semantics.get(raw_key, raw_key): value
        for raw_key, value in data.get("label_semantics", {}).items()
    }

    labels = tuple(data["labels"])
    group_type = GroupType[data["group_type"]]
    return cls(labels=labels, group_type=group_type, label_semantics=restored)
223+
138224

139225
@dataclass(frozen=True)
140226
class HierarchicalLabelCategory:
@@ -318,6 +404,96 @@ def __hash__(self):
318404
lg_repr = tuple((lg.name, tuple(lg.labels), lg.group_type) for lg in self.label_groups)
319405
return hash((self.items, lg_repr, frozenset(self.label_semantics.items())))
320406

407+
def to_dict(self) -> Dict[str, Any]:
    """
    Convert this HierarchicalLabelCategories instance into a JSON-compatible dictionary.

    Each item contributes its name, parent and label semantics; each label
    group contributes its name, labels (as a list) and group type (by enum
    member name). Label-semantics keys are written as strings.

    Returns:
        Dictionary representation of this HierarchicalLabelCategories instance
    """

    def _stringify_keys(semantics) -> Dict[str, Any]:
        # LabelSemantic keys are stored by member name, anything else via str().
        converted = {}
        for key, value in semantics.items():
            text = key.name if isinstance(key, LabelSemantic) else str(key)
            converted[text] = value
        return converted

    serialized_items = []
    for entry in self.items:
        serialized_items.append(
            {
                "name": entry.name,
                "parent": entry.parent,
                "label_semantics": _stringify_keys(entry.label_semantics),
            }
        )

    serialized_groups = []
    for grp in self.label_groups:
        serialized_groups.append(
            {
                "name": grp.name,
                "labels": list(grp.labels),
                "group_type": grp.group_type.name,
            }
        )

    return {
        "type": "HierarchicalLabelCategories",
        "items": serialized_items,
        "label_groups": serialized_groups,
        "label_semantics": _stringify_keys(self.label_semantics),
    }
440+
441+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HierarchicalLabelCategories":
    """
    Deserialize from a JSON dictionary.

    Args:
        data: Dictionary containing serialized HierarchicalLabelCategories
            data. "items" is required; "label_groups" and "label_semantics"
            default to empty when absent.

    Returns:
        Reconstructed HierarchicalLabelCategories instance
    """

    def _restore_semantics(raw: Dict[str, Any]) -> Dict[Any, Any]:
        # Map stored key names back to LabelSemantic members; keys that are
        # not member names are kept as the strings they were stored as.
        restored = {}
        for name, value in raw.items():
            try:
                key = LabelSemantic[name]
            except KeyError:
                key = name
            restored[key] = value
        return restored

    items = tuple(
        HierarchicalLabelCategory(
            name=item_dict["name"],
            parent=item_dict.get("parent", ""),
            label_semantics=_restore_semantics(item_dict.get("label_semantics", {})),
        )
        for item_dict in data["items"]
    )

    label_groups = tuple(
        LabelGroup(
            name=group_dict["name"],
            labels=tuple(group_dict["labels"]),
            group_type=GroupType[group_dict["group_type"]],
        )
        for group_dict in data.get("label_groups", [])
    )

    return cls(
        items=items,
        label_groups=label_groups,
        label_semantics=_restore_semantics(data.get("label_semantics", {})),
    )
496+
321497

322498
class RgbColor(NamedTuple):
323499
"""RGB color representation with named fields."""
@@ -386,6 +562,43 @@ class MaskCategories(Categories):
386562
def __hash__(self):
387563
return hash((tuple(self.labels), frozenset(self.colormap.items())))
388564

565+
def to_dict(self) -> Dict[str, Any]:
    """
    Convert this MaskCategories instance into a JSON-compatible dictionary.

    Colormap indices become string keys and every color is written as an
    ``[r, g, b]`` list.

    Returns:
        Dictionary representation of this MaskCategories instance
    """
    payload: Dict[str, Any] = {
        "type": "MaskCategories",
        "labels": list(self.labels),
    }

    palette: Dict[str, Any] = {}
    for index, rgb in self.colormap.data.items():
        palette[str(index)] = [rgb.r, rgb.g, rgb.b]
    payload["colormap"] = palette

    return payload
579+
580+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MaskCategories":
    """
    Rebuild a MaskCategories instance from its dictionary form.

    Colormap keys are converted back to integers and each stored color list
    is unpacked into an RgbColor. Missing "labels" or "colormap" entries are
    treated as empty.

    Args:
        data: Dictionary containing serialized MaskCategories data

    Returns:
        Reconstructed MaskCategories instance
    """
    labels = data.get("labels", [])
    stored_colors = data.get("colormap", {})

    palette = {
        int(index_text): RgbColor(*channels)
        for index_text, channels in stored_colors.items()
    }

    return cls(labels=labels, colormap=Colormap(data=palette))
601+
389602
@classmethod
390603
def generate(cls, size: int = 255, include_background: bool = True) -> "MaskCategories":
391604
"""

0 commit comments

Comments
 (0)