|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import abc |
| 6 | +import types |
6 | 7 | import typing |
7 | 8 | from typing import Any |
8 | 9 |
|
9 | | -# TODO: Suppress Pydantic deprecation warnings when importing this |
10 | 10 | import datasets |
11 | 11 | import pydantic |
12 | | -import pydantic_core.core_schema |
13 | 12 |
|
14 | 13 |
|
15 | 14 | class PydanticToHFDatasets(abc.ABC): |
16 | 15 | """Collection of utilities for converting Pydantic models (types and instances) to HF's `datasets.Dataset`.""" |
17 | 16 |
|
| 17 | + _PRIMITIVES_MAP: dict[type, str] = { |
| 18 | + str: "string", |
| 19 | + int: "int32", |
| 20 | + float: "float32", |
| 21 | + bool: "bool", |
| 22 | + } |
| 23 | + |
18 | 24 | @classmethod |
19 | 25 | def model_cls_to_features(cls, entity_type: type[pydantic.BaseModel]) -> datasets.Features: |
20 | 26 | """Given a Pydantic model, build a `datasets.Sequence` of features that match its fields. |
21 | 27 |
|
22 | | - :param entity_type: Entity type |
23 | | - :return: `datasets.Features` instance for use in HF `datasets.Dataset`. |
| 28 | + :param entity_type: The Pydantic model class to convert. |
| 29 | + :return: A `datasets.Features` instance for use in a Hugging Face `datasets.Dataset`. |
24 | 30 | """ |
25 | | - field_features: dict[str, datasets.Value] = {} |
| 31 | + field_features: dict[str, datasets.Value | datasets.Sequence | datasets.Features] = {} |
26 | 32 |
|
27 | | - # TODO Suppress warnings about model_fields access. |
28 | 33 | for field_name, field_info in entity_type.model_fields.items(): |
29 | | - # field_info.annotation is e.g. str, list[str], MyNestedModel, etc. |
30 | | - field_features[field_name] = cls._annotation_to_values(field_info.annotation) # type: ignore[arg-type] |
| 34 | + field_features[field_name] = cls._annotation_to_values(field_info.annotation) |
31 | 35 |
|
32 | 36 | return datasets.Features(field_features) |
33 | 37 |
|
34 | 38 | @classmethod |
35 | | - def _annotation_to_values( |
36 | | - cls, annotation: pydantic_core.core_schema.ModelField | type |
37 | | - ) -> datasets.Value | datasets.Sequence: |
38 | | - """Convert a type annotation (e.g. str, list[int], MyNestedModel) to a Hugging Face `datasets` feature. |
39 | | -
|
40 | | - Handles: |
41 | | - - Basic python types (str, int, float, bool) |
42 | | - - Lists/tuples (e.g. list[str], tuple[int], fallback for heterogeneous) |
43 | | - - Dict[str, ...] => Sequence of { "key": str, "value": ... } |
44 | | - - Nested Pydantic BaseModel |
45 | | - - Union/Optional => fallback to string |
46 | | - - Catch-all fallback => string |
47 | | -
|
48 | | - :param annotation: Annotation to convert. |
49 | | - :return: `datasets.Value` or `datasets.Sequence` instance generated from specified annotation. |
| 39 | + def _annotation_to_values(cls, annotation: Any) -> datasets.Value | datasets.Sequence | datasets.Features: |
| 40 | + """Convert a type annotation to a Hugging Face `datasets` feature. |
| 41 | +
|
| 42 | + :param annotation: The type annotation to convert (e.g., str, list[int], or a Pydantic model). |
| 43 | + :return: A Hugging Face dataset feature instance generated from the specified annotation. |
50 | 44 | """ |
51 | 45 | origin = typing.get_origin(annotation) |
52 | 46 | args = typing.get_args(annotation) |
53 | 47 |
|
54 | | - # 1) If annotation is a subclass of BaseModel, recursively build features |
| 48 | + # 1) Nested Pydantic Model. |
55 | 49 | if isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel): |
56 | 50 | return cls.model_cls_to_features(annotation) |
57 | 51 |
|
58 | | - # 2) Handle list[...] or tuple[...] |
| 52 | + # 2) Sequences (list, tuple). |
59 | 53 | if origin in (list, tuple): |
60 | | - if len(args) == 1: |
61 | | - # e.g. list[str], tuple[int]. noqa: ERA001 |
62 | | - item_type = args[0] |
63 | | - return datasets.Sequence(cls._annotation_to_values(item_type)) |
64 | | - elif len(args) > 1 and origin is tuple: |
65 | | - # e.g. tuple[str, int] => fallback to storing as string |
66 | | - return datasets.Sequence(datasets.Value("string")) |
67 | | - else: |
68 | | - # fallback |
69 | | - return datasets.Sequence(datasets.Value("string")) |
70 | | - |
71 | | - # 3) Handle dict[...] => convert to sequence of { "key": str, "value": ... } |
| 54 | + return cls._handle_sequence_annotation(args) |
| 55 | + |
| 56 | + # 3) Dictionaries. |
72 | 57 | if origin is dict: |
73 | | - # Typically we have 2 type args: key_type, value_type |
74 | | - if len(args) == 2: |
75 | | - key_type, value_type = args |
76 | | - if key_type is str: |
77 | | - # For dict[str, T], store as a sequence of key-value pairs |
78 | | - return datasets.Sequence( |
79 | | - feature=datasets.Features( |
80 | | - {"key": datasets.Value("string"), "value": cls._annotation_to_values(value_type)} |
81 | | - ) |
82 | | - ) |
83 | | - # If untyped or non-string keys, store as JSON string |
84 | | - return datasets.Value("string") |
85 | | - |
86 | | - # 4) If Union/Optional => fallback to string |
87 | | - if origin == typing.Union: |
88 | | - return datasets.Value("string") |
89 | | - |
90 | | - # 5) Basic primitives. Fallback: store as string. |
91 | | - primitives_map: dict[type | pydantic_core.core_schema.ModelField, str] = { |
92 | | - str: "string", |
93 | | - int: "int32", |
94 | | - float: "float32", |
95 | | - bool: "bool", |
96 | | - } |
97 | | - |
98 | | - return datasets.Value(primitives_map.get(annotation, "string")) |
| 58 | + return cls._handle_dict_annotation(args) |
| 59 | + |
| 60 | + # 4) Union / Optional. |
| 61 | + if origin in (typing.Union, getattr(types, "UnionType", None)): |
| 62 | + return cls._handle_union_annotation(args) |
| 63 | + |
| 64 | + # 5) Primitives & Fallback. |
| 65 | + return datasets.Value(cls._PRIMITIVES_MAP.get(annotation, "string")) |
| 66 | + |
| 67 | + @classmethod |
| 68 | + def _handle_sequence_annotation(cls, args: tuple[Any, ...]) -> datasets.Sequence: |
| 69 | + """Handle list[...] and tuple[...] annotations. |
| 70 | +
|
| 71 | + :param args: The type arguments of the sequence annotation. |
| 72 | + :return: A `datasets.Sequence` feature. |
| 73 | + """ |
| 74 | + if len(args) == 1: |
| 75 | + return datasets.Sequence(cls._annotation_to_values(args[0])) |
| 76 | + # Fallback for heterogeneous tuples or untyped lists. |
| 77 | + return datasets.Sequence(datasets.Value("string")) |
99 | 78 |
|
100 | 79 | @classmethod |
101 | | - def model_to_dict(cls, model: pydantic.BaseModel | None) -> Any: |
102 | | - """Given a Pydantic model instance (or nested structure), return a Python object (dict, list, etc.). |
103 | | -
|
104 | | - Matchies the Hugging Face Features schema defined by `_pydantic_annotation_to_hf_value`. |
105 | | - Handles: |
106 | | - - BaseModel subclasses (recursively) |
107 | | - - Lists / tuples |
108 | | - - Dict[str, X] => list of {"key": str, "value": X} |
109 | | - - Primitives |
110 | | - - Union / fallback => string |
111 | | -
|
112 | | - :param model: Entity to convert. |
113 | | - :return: Entity as dict aligned with the `datasets.Dataset` schema generated by |
114 | | - `PydanticHFConverter._model_to_features`. |
| 80 | + def _handle_dict_annotation(cls, args: tuple[Any, ...]) -> datasets.Sequence | datasets.Value: |
| 81 | + """Handle dict[...] annotations. |
| 82 | +
|
| 83 | + :param args: The type arguments of the dictionary annotation. |
| 84 | + :return: A `datasets.Sequence` for typed string-key dicts, or a `datasets.Value` fallback. |
| 85 | + """ |
| 86 | + if len(args) == 2 and args[0] is str: |
| 87 | + # For dict[str, T], store as a sequence of key-value pairs. |
| 88 | + return datasets.Sequence( |
| 89 | + feature=datasets.Features( |
| 90 | + {"key": datasets.Value("string"), "value": cls._annotation_to_values(args[1])} |
| 91 | + ) |
| 92 | + ) |
| 93 | + # Fallback for non-string keys or untyped dicts. |
| 94 | + return datasets.Value("string") |
| 95 | + |
| 96 | + @classmethod |
| 97 | + def _handle_union_annotation(cls, args: tuple[Any, ...]) -> datasets.Value | datasets.Sequence | datasets.Features: |
| 98 | + """Handle Union and Optional annotations. |
| 99 | +
|
| 100 | + :param args: The type arguments of the Union annotation. |
| 101 | + :return: The feature for the underlying type if Optional, otherwise a string fallback. |
| 102 | + """ |
| 103 | + underlying_type = cls._get_underlying_optional_type(args) |
| 104 | + if underlying_type: |
| 105 | + return cls._annotation_to_values(underlying_type) |
| 106 | + return datasets.Value("string") |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def model_to_dict(cls, model: pydantic.BaseModel | None) -> dict[str, Any] | None: |
| 110 | + """Convert a Pydantic model instance to a dict aligned with the HF dataset schema. |
| 111 | +
|
| 112 | + :param model: The Pydantic model instance to convert. |
| 113 | + :return: A dictionary representation of the model instance, or None if the input is None. |
115 | 114 | """ |
116 | | - # 0) If `entity` is None or truly empty |
117 | 115 | if model is None: |
118 | 116 | return None |
119 | 117 |
|
120 | | - # 1) If it's an actual Pydantic model instance |
121 | 118 | if isinstance(model, pydantic.BaseModel): |
122 | | - out = {} |
123 | | - # model_fields is a dict: field_name -> FieldInfo |
124 | | - # We read each field's value from the instance |
125 | | - for field_name, field_info in model.model_fields.items(): |
126 | | - annotation = field_info.annotation # e.g. str, list[int], SubModel |
| 119 | + out: dict[str, Any] = {} |
| 120 | + for field_name, field_info in type(model).model_fields.items(): |
127 | 121 | value = getattr(model, field_name) |
128 | | - out[field_name] = cls._convert_value_for_dataset(value, annotation) |
| 122 | + out[field_name] = cls._convert_value_for_dataset(value, field_info.annotation) |
129 | 123 | return out |
130 | 124 |
|
131 | | - # 2) If it’s not a model, we fallback to checking the type annotation dynamically or just returning the raw. |
132 | | - # But typically you'd _call this function on the *top-level Pydantic model instance*. |
133 | | - # For safety: |
134 | | - return model # type: ignore[unreachable] |
| 125 | + return model # type: ignore[return-value] |
135 | 126 |
|
136 | 127 | @classmethod |
137 | 128 | def _convert_value_for_dataset(cls, value: Any, annotation: Any) -> Any: |
138 | | - """Recursively convert a value (with its declared annotation) to something that fits the HF dataset row format. |
| 129 | + """Recursively convert a value to something that fits the HF dataset row format. |
139 | 130 |
|
140 | | - Parallel to `_pydantic_annotation_to_hf_value`. |
141 | | -
|
142 | | - :param value: Value to convert. |
143 | | - :param annotation: Type annotation of value. |
144 | | - :return Any: Converted value. |
| 131 | + :param value: The value to convert. |
| 132 | + :param annotation: The type annotation associated with the value. |
| 133 | + :return: The converted value compatible with Hugging Face datasets. |
145 | 134 | """ |
146 | | - # Handle None or missing |
147 | 135 | if value is None: |
148 | 136 | return None |
149 | 137 |
|
150 | 138 | origin = typing.get_origin(annotation) |
151 | 139 | args = typing.get_args(annotation) |
152 | 140 |
|
153 | | - # 1) Nested Pydantic model |
| 141 | + # 1) Nested Pydantic Model. |
154 | 142 | if isinstance(value, pydantic.BaseModel): |
155 | 143 | return cls.model_to_dict(value) |
156 | 144 |
|
157 | | - # 2) list[...] or tuple[...] |
| 145 | + # 2) Sequences (list, tuple). |
158 | 146 | if origin in (list, tuple): |
159 | | - # If it's actually a list/tuple, recursively process items |
160 | | - if isinstance(value, list | tuple): |
161 | | - if len(args) == 1: |
162 | | - # e.g. list[str], list[SomeSubModel]. |
163 | | - item_type = args[0] |
164 | | - return [cls._convert_value_for_dataset(v, item_type) for v in value] |
165 | | - elif len(args) > 1 and origin is tuple: |
166 | | - # tuple[str, int, ...] => fallback to string or handle partial |
167 | | - return [str(v) for v in value] |
168 | | - else: |
169 | | - # fallback |
170 | | - return [str(v) for v in value] |
171 | | - else: |
172 | | - # If the actual data isn't a list/tuple, fallback |
173 | | - return str(value) |
174 | | - |
175 | | - # 3) dict[str, X] => store as list of { "key": str, "value": X } |
| 147 | + return cls._handle_sequence_value(value, args) |
| 148 | + |
| 149 | + # 3) Dictionaries. |
176 | 150 | if origin is dict: |
177 | | - # Check if the actual data is indeed a dict |
178 | | - if isinstance(value, dict): |
179 | | - if len(args) == 2: |
180 | | - key_type, val_type = args |
181 | | - # only handle str-key dicts |
182 | | - if key_type is str: |
183 | | - kv_list = [] |
184 | | - for k, v in value.items(): |
185 | | - # Convert each item recursively |
186 | | - converted_val = cls._convert_value_for_dataset(v, val_type) |
187 | | - kv_list.append({"key": str(k), "value": converted_val}) |
188 | | - return kv_list |
189 | | - # else fallback -> store entire dict as a string |
190 | | - return str(value) |
191 | | - else: |
192 | | - # Not actually a dict |
193 | | - return str(value) |
194 | | - |
195 | | - # 4) Unions / Optionals => fallback to string (or refine if you want) |
196 | | - if origin == typing.Union: |
197 | | - # Typically means `Optional[X]` or `Union[X, Y]`. |
198 | | - # We'll just store it as string: |
199 | | - return str(value) |
| 151 | + return cls._handle_dict_value(value, args) |
200 | 152 |
|
201 | | - # 5) If annotation is a direct primitive type |
202 | | - # Just return the value as-is |
| 153 | + # 4) Union / Optional. |
| 154 | + if origin in (typing.Union, getattr(types, "UnionType", None)): |
| 155 | + return cls._handle_union_value(value, args) |
| 156 | + |
| 157 | + # 5) Primitives & fallback. |
203 | 158 | if annotation in (str, int, float, bool): |
204 | 159 | return value |
| 160 | + return str(value) |
| 161 | + |
| 162 | + @classmethod |
| 163 | + def _handle_sequence_value(cls, value: Any, args: tuple[Any, ...]) -> list[Any] | str: |
| 164 | + """Handle sequence values. |
| 165 | +
|
| 166 | + :param value: The sequence value to convert. |
| 167 | + :param args: The type arguments of the sequence annotation. |
| 168 | + :return: A list of converted values, or a string representation fallback. |
| 169 | + """ |
| 170 | + if not isinstance(value, (list, tuple)): # noqa: UP038 |
| 171 | + return str(value) |
| 172 | + |
| 173 | + if len(args) == 1: |
| 174 | + return [cls._convert_value_for_dataset(v, args[0]) for v in value] |
| 175 | + |
| 176 | + return [str(v) for v in value] |
| 177 | + |
| 178 | + @classmethod |
| 179 | + def _handle_dict_value(cls, value: Any, args: tuple[Any, ...]) -> list[dict[str, Any]] | str: |
| 180 | + """Handle dictionary values. |
| 181 | +
|
| 182 | + :param value: The dictionary value to convert. |
| 183 | + :param args: The type arguments of the dictionary annotation. |
| 184 | + :return: A list of key-value pair dictionaries, or a string representation fallback. |
| 185 | + """ |
| 186 | + if not isinstance(value, dict): |
| 187 | + return str(value) |
| 188 | + |
| 189 | + if len(args) == 2 and args[0] is str: |
| 190 | + return [{"key": str(k), "value": cls._convert_value_for_dataset(v, args[1])} for k, v in value.items()] |
205 | 191 |
|
206 | | - # 6) If it's a fallback -> store as string |
207 | 192 | return str(value) |
| 193 | + |
| 194 | + @classmethod |
| 195 | + def _handle_union_value(cls, value: Any, args: tuple[Any, ...]) -> Any: |
| 196 | + """Handle Union and Optional values. |
| 197 | +
|
| 198 | + :param value: The value to convert. |
| 199 | + :param args: The type arguments of the Union annotation. |
| 200 | + :return: The converted value if Optional, otherwise a string representation fallback. |
| 201 | + """ |
| 202 | + underlying_type = cls._get_underlying_optional_type(args) |
| 203 | + if underlying_type: |
| 204 | + return cls._convert_value_for_dataset(value, underlying_type) |
| 205 | + return str(value) |
| 206 | + |
| 207 | + @classmethod |
| 208 | + def _get_underlying_optional_type(cls, args: tuple[Any, ...]) -> Any | None: |
| 209 | + """Extract T from Optional[T] / Union[T, None]. |
| 210 | +
|
| 211 | + :param args: The type arguments of the Union annotation. |
| 212 | + :return: The underlying type T if it is an Optional[T], otherwise None. |
| 213 | + """ |
| 214 | + non_none_args = [arg for arg in args if arg is not type(None)] |
| 215 | + return non_none_args[0] if len(non_none_args) == 1 else None |
0 commit comments