Skip to content

Commit 5c655e2

Browse files
authored
feat: Improve PydanticToHFDatasets.model_to_dict() (#239)
1 parent 04d2460 commit 5c655e2

File tree

2 files changed

+207
-141
lines changed

2 files changed

+207
-141
lines changed

sieves/tasks/utils.py

Lines changed: 149 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -3,205 +3,213 @@
33
from __future__ import annotations
44

55
import abc
6+
import types
67
import typing
78
from typing import Any
89

9-
# TODO: Suppress Pydantic deprecation warnings when importing this
1010
import datasets
1111
import pydantic
12-
import pydantic_core.core_schema
1312

1413

1514
class PydanticToHFDatasets(abc.ABC):
1615
"""Collection of utilities for converting Pydantic models (types and instances) to HF's `datasets.Dataset`."""
1716

17+
_PRIMITIVES_MAP: dict[type, str] = {
18+
str: "string",
19+
int: "int32",
20+
float: "float32",
21+
bool: "bool",
22+
}
23+
1824
@classmethod
1925
def model_cls_to_features(cls, entity_type: type[pydantic.BaseModel]) -> datasets.Features:
2026
"""Given a Pydantic model class, build a `datasets.Features` mapping whose entries match its fields.
2127
22-
:param entity_type: Entity type
23-
:return: `datasets.Features` instance for use in HF `datasets.Dataset`.
28+
:param entity_type: The Pydantic model class to convert.
29+
:return: A `datasets.Features` instance for use in a Hugging Face `datasets.Dataset`.
2430
"""
25-
field_features: dict[str, datasets.Value] = {}
31+
field_features: dict[str, datasets.Value | datasets.Sequence | datasets.Features] = {}
2632

27-
# TODO Suppress warnings about model_fields access.
2833
for field_name, field_info in entity_type.model_fields.items():
29-
# field_info.annotation is e.g. str, list[str], MyNestedModel, etc.
30-
field_features[field_name] = cls._annotation_to_values(field_info.annotation) # type: ignore[arg-type]
34+
field_features[field_name] = cls._annotation_to_values(field_info.annotation)
3135

3236
return datasets.Features(field_features)
3337

3438
@classmethod
35-
def _annotation_to_values(
36-
cls, annotation: pydantic_core.core_schema.ModelField | type
37-
) -> datasets.Value | datasets.Sequence:
38-
"""Convert a type annotation (e.g. str, list[int], MyNestedModel) to a Hugging Face `datasets` feature.
39-
40-
Handles:
41-
- Basic python types (str, int, float, bool)
42-
- Lists/tuples (e.g. list[str], tuple[int], fallback for heterogeneous)
43-
- Dict[str, ...] => Sequence of { "key": str, "value": ... }
44-
- Nested Pydantic BaseModel
45-
- Union/Optional => fallback to string
46-
- Catch-all fallback => string
47-
48-
:param annotation: Annotation to convert.
49-
:return: `datasets.Value` or `datasets.Sequence` instance generated from specified annotation.
39+
def _annotation_to_values(cls, annotation: Any) -> datasets.Value | datasets.Sequence | datasets.Features:
40+
"""Convert a type annotation to a Hugging Face `datasets` feature.
41+
42+
:param annotation: The type annotation to convert (e.g., str, list[int], or a Pydantic model).
43+
:return: A Hugging Face dataset feature instance generated from the specified annotation.
5044
"""
5145
origin = typing.get_origin(annotation)
5246
args = typing.get_args(annotation)
5347

54-
# 1) If annotation is a subclass of BaseModel, recursively build features
48+
# 1) Nested Pydantic Model.
5549
if isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel):
5650
return cls.model_cls_to_features(annotation)
5751

58-
# 2) Handle list[...] or tuple[...]
52+
# 2) Sequences (list, tuple).
5953
if origin in (list, tuple):
60-
if len(args) == 1:
61-
# e.g. list[str], tuple[int]. noqa: ERA001
62-
item_type = args[0]
63-
return datasets.Sequence(cls._annotation_to_values(item_type))
64-
elif len(args) > 1 and origin is tuple:
65-
# e.g. tuple[str, int] => fallback to storing as string
66-
return datasets.Sequence(datasets.Value("string"))
67-
else:
68-
# fallback
69-
return datasets.Sequence(datasets.Value("string"))
70-
71-
# 3) Handle dict[...] => convert to sequence of { "key": str, "value": ... }
54+
return cls._handle_sequence_annotation(args)
55+
56+
# 3) Dictionaries.
7257
if origin is dict:
73-
# Typically we have 2 type args: key_type, value_type
74-
if len(args) == 2:
75-
key_type, value_type = args
76-
if key_type is str:
77-
# For dict[str, T], store as a sequence of key-value pairs
78-
return datasets.Sequence(
79-
feature=datasets.Features(
80-
{"key": datasets.Value("string"), "value": cls._annotation_to_values(value_type)}
81-
)
82-
)
83-
# If untyped or non-string keys, store as JSON string
84-
return datasets.Value("string")
85-
86-
# 4) If Union/Optional => fallback to string
87-
if origin == typing.Union:
88-
return datasets.Value("string")
89-
90-
# 5) Basic primitives. Fallback: store as string.
91-
primitives_map: dict[type | pydantic_core.core_schema.ModelField, str] = {
92-
str: "string",
93-
int: "int32",
94-
float: "float32",
95-
bool: "bool",
96-
}
97-
98-
return datasets.Value(primitives_map.get(annotation, "string"))
58+
return cls._handle_dict_annotation(args)
59+
60+
# 4) Union / Optional.
61+
if origin in (typing.Union, getattr(types, "UnionType", None)):
62+
return cls._handle_union_annotation(args)
63+
64+
# 5) Primitives & Fallback.
65+
return datasets.Value(cls._PRIMITIVES_MAP.get(annotation, "string"))
66+
67+
@classmethod
68+
def _handle_sequence_annotation(cls, args: tuple[Any, ...]) -> datasets.Sequence:
69+
"""Handle list[...] and tuple[...] annotations.
70+
71+
:param args: The type arguments of the sequence annotation.
72+
:return: A `datasets.Sequence` feature.
73+
"""
74+
if len(args) == 1:
75+
return datasets.Sequence(cls._annotation_to_values(args[0]))
76+
# Fallback for heterogeneous tuples or untyped lists.
77+
return datasets.Sequence(datasets.Value("string"))
9978

10079
@classmethod
101-
def model_to_dict(cls, model: pydantic.BaseModel | None) -> Any:
102-
"""Given a Pydantic model instance (or nested structure), return a Python object (dict, list, etc.).
103-
104-
Matches the Hugging Face Features schema defined by `_pydantic_annotation_to_hf_value`.
105-
Handles:
106-
- BaseModel subclasses (recursively)
107-
- Lists / tuples
108-
- Dict[str, X] => list of {"key": str, "value": X}
109-
- Primitives
110-
- Union / fallback => string
111-
112-
:param model: Entity to convert.
113-
:return: Entity as dict aligned with the `datasets.Dataset` schema generated by
114-
`PydanticHFConverter._model_to_features`.
80+
def _handle_dict_annotation(cls, args: tuple[Any, ...]) -> datasets.Sequence | datasets.Value:
81+
"""Handle dict[...] annotations.
82+
83+
:param args: The type arguments of the dictionary annotation.
84+
:return: A `datasets.Sequence` for typed string-key dicts, or a `datasets.Value` fallback.
85+
"""
86+
if len(args) == 2 and args[0] is str:
87+
# For dict[str, T], store as a sequence of key-value pairs.
88+
return datasets.Sequence(
89+
feature=datasets.Features(
90+
{"key": datasets.Value("string"), "value": cls._annotation_to_values(args[1])}
91+
)
92+
)
93+
# Fallback for non-string keys or untyped dicts.
94+
return datasets.Value("string")
95+
96+
@classmethod
97+
def _handle_union_annotation(cls, args: tuple[Any, ...]) -> datasets.Value | datasets.Sequence | datasets.Features:
98+
"""Handle Union and Optional annotations.
99+
100+
:param args: The type arguments of the Union annotation.
101+
:return: The feature for the underlying type if Optional, otherwise a string fallback.
102+
"""
103+
underlying_type = cls._get_underlying_optional_type(args)
104+
if underlying_type:
105+
return cls._annotation_to_values(underlying_type)
106+
return datasets.Value("string")
107+
108+
@classmethod
109+
def model_to_dict(cls, model: pydantic.BaseModel | None) -> dict[str, Any] | None:
110+
"""Convert a Pydantic model instance to a dict aligned with the HF dataset schema.
111+
112+
:param model: The Pydantic model instance to convert.
113+
:return: A dictionary representation of the model instance, or None if the input is None.
115114
"""
116-
# 0) If `entity` is None or truly empty
117115
if model is None:
118116
return None
119117

120-
# 1) If it's an actual Pydantic model instance
121118
if isinstance(model, pydantic.BaseModel):
122-
out = {}
123-
# model_fields is a dict: field_name -> FieldInfo
124-
# We read each field's value from the instance
125-
for field_name, field_info in model.model_fields.items():
126-
annotation = field_info.annotation # e.g. str, list[int], SubModel
119+
out: dict[str, Any] = {}
120+
for field_name, field_info in type(model).model_fields.items():
127121
value = getattr(model, field_name)
128-
out[field_name] = cls._convert_value_for_dataset(value, annotation)
122+
out[field_name] = cls._convert_value_for_dataset(value, field_info.annotation)
129123
return out
130124

131-
# 2) If it’s not a model, we fallback to checking the type annotation dynamically or just returning the raw.
132-
# But typically you'd call this function on the *top-level Pydantic model instance*.
133-
# For safety:
134-
return model # type: ignore[unreachable]
125+
return model # type: ignore[return-value]
135126

136127
@classmethod
137128
def _convert_value_for_dataset(cls, value: Any, annotation: Any) -> Any:
138-
"""Recursively convert a value (with its declared annotation) to something that fits the HF dataset row format.
129+
"""Recursively convert a value to something that fits the HF dataset row format.
139130
140-
Parallel to `_pydantic_annotation_to_hf_value`.
141-
142-
:param value: Value to convert.
143-
:param annotation: Type annotation of value.
144-
:return Any: Converted value.
131+
:param value: The value to convert.
132+
:param annotation: The type annotation associated with the value.
133+
:return: The converted value compatible with Hugging Face datasets.
145134
"""
146-
# Handle None or missing
147135
if value is None:
148136
return None
149137

150138
origin = typing.get_origin(annotation)
151139
args = typing.get_args(annotation)
152140

153-
# 1) Nested Pydantic model
141+
# 1) Nested Pydantic Model.
154142
if isinstance(value, pydantic.BaseModel):
155143
return cls.model_to_dict(value)
156144

157-
# 2) list[...] or tuple[...]
145+
# 2) Sequences (list, tuple).
158146
if origin in (list, tuple):
159-
# If it's actually a list/tuple, recursively process items
160-
if isinstance(value, list | tuple):
161-
if len(args) == 1:
162-
# e.g. list[str], list[SomeSubModel].
163-
item_type = args[0]
164-
return [cls._convert_value_for_dataset(v, item_type) for v in value]
165-
elif len(args) > 1 and origin is tuple:
166-
# tuple[str, int, ...] => fallback to string or handle partial
167-
return [str(v) for v in value]
168-
else:
169-
# fallback
170-
return [str(v) for v in value]
171-
else:
172-
# If the actual data isn't a list/tuple, fallback
173-
return str(value)
174-
175-
# 3) dict[str, X] => store as list of { "key": str, "value": X }
147+
return cls._handle_sequence_value(value, args)
148+
149+
# 3) Dictionaries.
176150
if origin is dict:
177-
# Check if the actual data is indeed a dict
178-
if isinstance(value, dict):
179-
if len(args) == 2:
180-
key_type, val_type = args
181-
# only handle str-key dicts
182-
if key_type is str:
183-
kv_list = []
184-
for k, v in value.items():
185-
# Convert each item recursively
186-
converted_val = cls._convert_value_for_dataset(v, val_type)
187-
kv_list.append({"key": str(k), "value": converted_val})
188-
return kv_list
189-
# else fallback -> store entire dict as a string
190-
return str(value)
191-
else:
192-
# Not actually a dict
193-
return str(value)
194-
195-
# 4) Unions / Optionals => fallback to string (or refine if you want)
196-
if origin == typing.Union:
197-
# Typically means `Optional[X]` or `Union[X, Y]`.
198-
# We'll just store it as string:
199-
return str(value)
151+
return cls._handle_dict_value(value, args)
200152

201-
# 5) If annotation is a direct primitive type
202-
# Just return the value as-is
153+
# 4) Union / Optional.
154+
if origin in (typing.Union, getattr(types, "UnionType", None)):
155+
return cls._handle_union_value(value, args)
156+
157+
# 5) Primitives & fallback.
203158
if annotation in (str, int, float, bool):
204159
return value
160+
return str(value)
161+
162+
@classmethod
163+
def _handle_sequence_value(cls, value: Any, args: tuple[Any, ...]) -> list[Any] | str:
164+
"""Handle sequence values.
165+
166+
:param value: The sequence value to convert.
167+
:param args: The type arguments of the sequence annotation.
168+
:return: A list of converted values, or a string representation fallback.
169+
"""
170+
if not isinstance(value, (list, tuple)): # noqa: UP038
171+
return str(value)
172+
173+
if len(args) == 1:
174+
return [cls._convert_value_for_dataset(v, args[0]) for v in value]
175+
176+
return [str(v) for v in value]
177+
178+
@classmethod
179+
def _handle_dict_value(cls, value: Any, args: tuple[Any, ...]) -> list[dict[str, Any]] | str:
180+
"""Handle dictionary values.
181+
182+
:param value: The dictionary value to convert.
183+
:param args: The type arguments of the dictionary annotation.
184+
:return: A list of key-value pair dictionaries, or a string representation fallback.
185+
"""
186+
if not isinstance(value, dict):
187+
return str(value)
188+
189+
if len(args) == 2 and args[0] is str:
190+
return [{"key": str(k), "value": cls._convert_value_for_dataset(v, args[1])} for k, v in value.items()]
205191

206-
# 6) If it's a fallback -> store as string
207192
return str(value)
193+
194+
@classmethod
195+
def _handle_union_value(cls, value: Any, args: tuple[Any, ...]) -> Any:
196+
"""Handle Union and Optional values.
197+
198+
:param value: The value to convert.
199+
:param args: The type arguments of the Union annotation.
200+
:return: The converted value if Optional, otherwise a string representation fallback.
201+
"""
202+
underlying_type = cls._get_underlying_optional_type(args)
203+
if underlying_type:
204+
return cls._convert_value_for_dataset(value, underlying_type)
205+
return str(value)
206+
207+
@classmethod
208+
def _get_underlying_optional_type(cls, args: tuple[Any, ...]) -> Any | None:
209+
"""Extract T from Optional[T] / Union[T, None].
210+
211+
:param args: The type arguments of the Union annotation.
212+
:return: The underlying type T if it is an Optional[T], otherwise None.
213+
"""
214+
non_none_args = [arg for arg in args if arg is not type(None)]
215+
return non_none_args[0] if len(non_none_args) == 1 else None

0 commit comments

Comments
 (0)