99from datasets import Sequence , Value
1010from docling .datamodel .base_models import ConversionStatus
1111from docling_core .types import DoclingDocument
12+ from docling_core .types .doc .page import SegmentedPage
1213from docling_core .types .io import DocumentStream
13- from pydantic import BaseModel , ConfigDict , Field , model_validator
14+ from pydantic import BaseModel , ConfigDict , Field , TypeAdapter , model_validator
1415
1516from docling_eval .datamodels .types import EvaluationModality , PredictionFormats
1617
18+ seg_adapter = TypeAdapter (Dict [int , SegmentedPage ])
19+
1720
1821class DatasetRecord (
1922 BaseModel
@@ -24,6 +27,9 @@ class DatasetRecord(
2427 doc_hash : Optional [str ] = Field (alias = "document_filehash" , default = None )
2528
2629 ground_truth_doc : DoclingDocument = Field (alias = "GroundTruthDocument" )
30+ ground_truth_segmented_pages : Dict [int , SegmentedPage ] = Field (
31+ alias = "ground_truth_segmented_pages" , default = {}
32+ )
2733 original : Optional [Union [DocumentStream , Path ]] = Field (
2834 alias = "BinaryDocument" , default = None
2935 )
@@ -53,6 +59,7 @@ def features(cls):
5359 cls .get_field_alias ("doc_path" ): Value ("string" ),
5460 cls .get_field_alias ("doc_hash" ): Value ("string" ),
5561 cls .get_field_alias ("ground_truth_doc" ): Value ("string" ),
62+ cls .get_field_alias ("ground_truth_segmented_pages" ): Value ("string" ),
5663 cls .get_field_alias ("ground_truth_pictures" ): Sequence (
5764 Features_Image ()
5865 ),
@@ -102,6 +109,9 @@ def as_record_dict(self):
102109 self .ground_truth_doc .export_to_dict ()
103110 ),
104111 self .get_field_alias ("ground_truth_pictures" ): self .ground_truth_pictures ,
112+ self .get_field_alias ("ground_truth_segmented_pages" ): seg_adapter .dump_json (
113+ self .ground_truth_segmented_pages
114+ ),
105115 self .get_field_alias (
106116 "ground_truth_page_images"
107117 ): self .ground_truth_page_images ,
@@ -143,6 +153,12 @@ def validate_record_dict(cls, data: dict):
143153 if gt_doc_alias in data and isinstance (data [gt_doc_alias ], str ):
144154 data [gt_doc_alias ] = json .loads (data [gt_doc_alias ])
145155
156+ gt_seg_pages_alias = cls .get_field_alias ("ground_truth_segmented_pages" )
157+ if gt_seg_pages_alias in data and isinstance (data [gt_seg_pages_alias ], bytes ):
158+ data [gt_seg_pages_alias ] = seg_adapter .validate_json (
159+ data [gt_seg_pages_alias ]
160+ )
161+
146162 gt_page_img_alias = cls .get_field_alias ("ground_truth_page_images" )
147163 if gt_page_img_alias in data :
148164 for ix , item in enumerate (data [gt_page_img_alias ]):
@@ -171,6 +187,11 @@ class DatasetRecordWithPrediction(DatasetRecord):
171187 predicted_doc : Optional [DoclingDocument ] = Field (
172188 alias = "PredictedDocument" , default = None
173189 )
190+
191+ predicted_segmented_pages : Dict [int , SegmentedPage ] = Field (
192+ alias = "predicted_segmented_pages" , default = {}
193+ )
194+
174195 original_prediction : Optional [str ] = None
175196 prediction_format : PredictionFormats # some enum type
176197 prediction_timings : Optional [Dict ] = Field (alias = "prediction_timings" , default = None )
@@ -187,20 +208,22 @@ class DatasetRecordWithPrediction(DatasetRecord):
187208 @classmethod
188209 def features (cls ):
189210 return {
190- cls .get_field_alias ("predictor_info" ): Value ("string" ),
191- cls .get_field_alias ("status" ): Value ("string" ),
192211 cls .get_field_alias ("doc_id" ): Value ("string" ),
193212 cls .get_field_alias ("doc_path" ): Value ("string" ),
194213 cls .get_field_alias ("doc_hash" ): Value ("string" ),
195214 cls .get_field_alias ("ground_truth_doc" ): Value ("string" ),
215+ cls .get_field_alias ("ground_truth_segmented_pages" ): Value ("string" ),
196216 cls .get_field_alias ("ground_truth_pictures" ): Sequence (Features_Image ()),
197217 cls .get_field_alias ("ground_truth_page_images" ): Sequence (Features_Image ()),
198- cls .get_field_alias ("predicted_doc" ): Value ("string" ),
199- cls .get_field_alias ("predicted_pictures" ): Sequence (Features_Image ()),
200- cls .get_field_alias ("predicted_page_images" ): Sequence (Features_Image ()),
201218 cls .get_field_alias ("original" ): Value ("string" ),
202219 cls .get_field_alias ("mime_type" ): Value ("string" ),
203220 cls .get_field_alias ("modalities" ): Sequence (Value ("string" )),
221+ cls .get_field_alias ("predictor_info" ): Value ("string" ),
222+ cls .get_field_alias ("status" ): Value ("string" ),
223+ cls .get_field_alias ("predicted_doc" ): Value ("string" ),
224+ cls .get_field_alias ("predicted_segmented_pages" ): Value ("string" ),
225+ cls .get_field_alias ("predicted_pictures" ): Sequence (Features_Image ()),
226+ cls .get_field_alias ("predicted_page_images" ): Sequence (Features_Image ()),
204227 cls .get_field_alias ("prediction_format" ): Value ("string" ),
205228 cls .get_field_alias ("prediction_timings" ): Value ("string" ),
206229 }
@@ -211,6 +234,8 @@ def as_record_dict(self):
211234 {
212235 self .get_field_alias ("prediction_format" ): self .prediction_format .value ,
213236 self .get_field_alias ("prediction_timings" ): self .prediction_timings ,
237+ self .get_field_alias ("predictor_info" ): self .predictor_info ,
238+ self .get_field_alias ("status" ): (self .status ),
214239 }
215240 )
216241
@@ -220,15 +245,16 @@ def as_record_dict(self):
220245 self .get_field_alias ("predicted_doc" ): json .dumps (
221246 self .predicted_doc .export_to_dict ()
222247 ),
248+ self .get_field_alias (
249+ "predicted_segmented_pages"
250+ ): seg_adapter .dump_json (self .predicted_segmented_pages ),
223251 self .get_field_alias ("predicted_pictures" ): self .predicted_pictures ,
224252 self .get_field_alias (
225253 "predicted_page_images"
226254 ): self .predicted_page_images ,
227255 self .get_field_alias ("original_prediction" ): (
228256 self .original_prediction
229257 ),
230- self .get_field_alias ("status" ): (self .status ),
231- self .get_field_alias ("predictor_info" ): self .predictor_info ,
232258 }
233259 )
234260
@@ -262,6 +288,14 @@ def validate_prediction_record_dict(cls, data: dict):
262288 if pred_doc_alias in data and isinstance (data [pred_doc_alias ], str ):
263289 data [pred_doc_alias ] = json .loads (data [pred_doc_alias ])
264290
291+ pred_seg_pages_alias = cls .get_field_alias ("predicted_segmented_pages" )
292+ if pred_seg_pages_alias in data and isinstance (
293+ data [pred_seg_pages_alias ], bytes
294+ ):
295+ data [pred_seg_pages_alias ] = seg_adapter .validate_json (
296+ data [pred_seg_pages_alias ]
297+ )
298+
265299 pred_page_img_alias = cls .get_field_alias ("predicted_page_images" )
266300 if pred_page_img_alias in data :
267301 for ix , item in enumerate (data [pred_page_img_alias ]):
0 commit comments