|
1 | 1 | import json |
| 2 | +from enum import Enum |
2 | 3 | from io import BytesIO |
3 | 4 | from pathlib import Path |
4 | 5 | from typing import Dict, List, Optional, Union |
|
19 | 20 | seg_adapter = TypeAdapter(Dict[int, SegmentedPage]) |
20 | 21 |
|
21 | 22 |
|
class FieldType(Enum):
    """Logical column kinds used when describing a dataset record schema.

    Each member is a backend-neutral tag; the schema helpers in this module
    translate a tag into the concrete HuggingFace or PyArrow type.
    """

    # Scalar columns.
    STRING = "string"
    BINARY = "binary"

    # List-valued columns.
    IMAGE_LIST = "image_list"
    STRING_LIST = "string_list"
| 28 | + |
| 29 | + |
class SchemaGenerator:
    """Derive HuggingFace ``Features`` and PyArrow schemas from one field definition.

    A field definition is a mapping of column name to :class:`FieldType`.
    Both generators walk that mapping in its iteration order, so the two
    resulting schemas list the same columns in the same order.
    """

    @staticmethod
    def _get_features_type(field_type: FieldType):
        """Return the HuggingFace feature object that represents ``field_type``.

        Raises:
            KeyError: If ``field_type`` is not a key of the lookup table.
        """
        hf_types = {
            FieldType.STRING: Value("string"),
            FieldType.BINARY: Value("binary"),
            FieldType.IMAGE_LIST: Sequence(Features_Image()),
            FieldType.STRING_LIST: Sequence(Value("string")),
        }
        return hf_types[field_type]

    @staticmethod
    def _get_pyarrow_type(field_type: FieldType):
        """Return the PyArrow data type that represents ``field_type``.

        Raises:
            KeyError: If ``field_type`` is not a key of the lookup table.
        """
        # Imported here rather than at module level, as in the original code.
        import pyarrow as pa

        # An image is stored as a struct holding raw bytes and a path string.
        image_struct = pa.struct([("bytes", pa.binary()), ("path", pa.string())])

        arrow_types = {
            FieldType.STRING: pa.string(),
            FieldType.BINARY: pa.binary(),
            FieldType.IMAGE_LIST: pa.list_(image_struct),
            FieldType.STRING_LIST: pa.list_(pa.string()),
        }
        return arrow_types[field_type]

    @classmethod
    def generate_features(cls, field_definitions: Dict[str, FieldType]) -> Features:
        """Build a ``Features`` object with one entry per defined field.

        Args:
            field_definitions: Mapping of column name to its ``FieldType``.

        Returns:
            ``Features`` keyed by the same column names.
        """
        feature_by_column = {}
        for column, kind in field_definitions.items():
            feature_by_column[column] = cls._get_features_type(kind)
        return Features(feature_by_column)

    @classmethod
    def generate_pyarrow_schema(cls, field_definitions: Dict[str, FieldType]):
        """Build a PyArrow schema with one field per defined column.

        Args:
            field_definitions: Mapping of column name to its ``FieldType``.

        Returns:
            The result of ``pa.schema`` over ``(name, type)`` pairs, in the
            mapping's iteration order.
        """
        import pyarrow as pa

        columns = []
        for column, kind in field_definitions.items():
            columns.append((column, cls._get_pyarrow_type(kind)))
        return pa.schema(columns)
| 76 | + |
| 77 | + |
22 | 78 | class DatasetRecord( |
23 | 79 | BaseModel |
24 | 80 | ): # TODO make predictionrecord class, factor prediction-related fields there. |
@@ -51,26 +107,30 @@ class DatasetRecord( |
51 | 107 | def get_field_alias(cls, field_name: str) -> str: |
52 | 108 | return cls.model_fields[field_name].alias or field_name |
53 | 109 |
|
| 110 | + @classmethod |
| 111 | + def _get_field_definitions(cls) -> Dict[str, FieldType]: |
| 112 | + """Define the schema for this class. Override in subclasses to extend.""" |
| 113 | + return { |
| 114 | + cls.get_field_alias("doc_id"): FieldType.STRING, |
| 115 | + cls.get_field_alias("doc_path"): FieldType.STRING, |
| 116 | + cls.get_field_alias("doc_hash"): FieldType.STRING, |
| 117 | + cls.get_field_alias("ground_truth_doc"): FieldType.STRING, |
| 118 | + cls.get_field_alias("ground_truth_segmented_pages"): FieldType.STRING, |
| 119 | + cls.get_field_alias("ground_truth_pictures"): FieldType.IMAGE_LIST, |
| 120 | + cls.get_field_alias("ground_truth_page_images"): FieldType.IMAGE_LIST, |
| 121 | + cls.get_field_alias("original"): FieldType.BINARY, |
| 122 | + cls.get_field_alias("mime_type"): FieldType.STRING, |
| 123 | + cls.get_field_alias("modalities"): FieldType.STRING_LIST, |
| 124 | + } |
| 125 | + |
54 | 126 | @classmethod |
55 | 127 | def features(cls): |
56 | | - return Features( |
57 | | - { |
58 | | - cls.get_field_alias("doc_id"): Value("string"), |
59 | | - cls.get_field_alias("doc_path"): Value("string"), |
60 | | - cls.get_field_alias("doc_hash"): Value("string"), |
61 | | - cls.get_field_alias("ground_truth_doc"): Value("string"), |
62 | | - cls.get_field_alias("ground_truth_segmented_pages"): Value("string"), |
63 | | - cls.get_field_alias("ground_truth_pictures"): Sequence( |
64 | | - Features_Image() |
65 | | - ), |
66 | | - cls.get_field_alias("ground_truth_page_images"): Sequence( |
67 | | - Features_Image() |
68 | | - ), |
69 | | - cls.get_field_alias("original"): Value("binary"), |
70 | | - cls.get_field_alias("mime_type"): Value("string"), |
71 | | - cls.get_field_alias("modalities"): Sequence(Value("string")), |
72 | | - } |
73 | | - ) |
| 128 | + return SchemaGenerator.generate_features(cls._get_field_definitions()) |
| 129 | + |
| 130 | + @classmethod |
| 131 | + def pyarrow_schema(cls): |
| 132 | + """Generate PyArrow schema that matches the HuggingFace datasets image format.""" |
| 133 | + return SchemaGenerator.generate_pyarrow_schema(cls._get_field_definitions()) |
74 | 134 |
|
75 | 135 | def _extract_images( |
76 | 136 | self, |
@@ -207,37 +267,31 @@ class DatasetRecordWithPrediction(DatasetRecord): |
207 | 267 |
|
208 | 268 | model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) |
209 | 269 |
|
| 270 | + @classmethod |
| 271 | + def _get_field_definitions(cls) -> Dict[str, FieldType]: |
| 272 | + """Extend the parent schema with prediction-specific fields.""" |
| 273 | + base_definitions = super()._get_field_definitions() |
| 274 | + prediction_definitions = { |
| 275 | + cls.get_field_alias("predictor_info"): FieldType.STRING, |
| 276 | + cls.get_field_alias("status"): FieldType.STRING, |
| 277 | + cls.get_field_alias("predicted_doc"): FieldType.STRING, |
| 278 | + cls.get_field_alias("predicted_segmented_pages"): FieldType.STRING, |
| 279 | + cls.get_field_alias("predicted_pictures"): FieldType.IMAGE_LIST, |
| 280 | + cls.get_field_alias("predicted_page_images"): FieldType.IMAGE_LIST, |
| 281 | + cls.get_field_alias("prediction_format"): FieldType.STRING, |
| 282 | + cls.get_field_alias("prediction_timings"): FieldType.STRING, |
| 283 | + cls.get_field_alias("original_prediction"): FieldType.STRING, |
| 284 | + } |
| 285 | + return {**base_definitions, **prediction_definitions} |
| 286 | + |
210 | 287 | @classmethod |
211 | 288 | def features(cls): |
212 | | - return Features( |
213 | | - { |
214 | | - cls.get_field_alias("doc_id"): Value("string"), |
215 | | - cls.get_field_alias("doc_path"): Value("string"), |
216 | | - cls.get_field_alias("doc_hash"): Value("string"), |
217 | | - cls.get_field_alias("ground_truth_doc"): Value("string"), |
218 | | - cls.get_field_alias("ground_truth_segmented_pages"): Value("string"), |
219 | | - cls.get_field_alias("ground_truth_pictures"): Sequence( |
220 | | - Features_Image() |
221 | | - ), |
222 | | - cls.get_field_alias("ground_truth_page_images"): Sequence( |
223 | | - Features_Image() |
224 | | - ), |
225 | | - cls.get_field_alias("original"): Value("binary"), |
226 | | - cls.get_field_alias("mime_type"): Value("string"), |
227 | | - cls.get_field_alias("modalities"): Sequence(Value("string")), |
228 | | - cls.get_field_alias("predictor_info"): Value("string"), |
229 | | - cls.get_field_alias("status"): Value("string"), |
230 | | - cls.get_field_alias("predicted_doc"): Value("string"), |
231 | | - cls.get_field_alias("predicted_segmented_pages"): Value("string"), |
232 | | - cls.get_field_alias("predicted_pictures"): Sequence(Features_Image()), |
233 | | - cls.get_field_alias("predicted_page_images"): Sequence( |
234 | | - Features_Image() |
235 | | - ), |
236 | | - cls.get_field_alias("prediction_format"): Value("string"), |
237 | | - cls.get_field_alias("prediction_timings"): Value("string"), |
238 | | - cls.get_field_alias("original_prediction"): Value("string"), |
239 | | - } |
240 | | - ) |
| 289 | + return SchemaGenerator.generate_features(cls._get_field_definitions()) |
| 290 | + |
| 291 | + @classmethod |
| 292 | + def pyarrow_schema(cls): |
| 293 | + """Generate PyArrow schema that matches the HuggingFace datasets image format.""" |
| 294 | + return SchemaGenerator.generate_pyarrow_schema(cls._get_field_definitions()) |
241 | 295 |
|
242 | 296 | def as_record_dict(self): |
243 | 297 | record = super().as_record_dict() |
|
0 commit comments