|
| 1 | +import io |
| 2 | +import logging |
| 3 | +from pathlib import Path |
| 4 | +from typing import Iterable, List, Optional, Set |
| 5 | + |
| 6 | +import PIL.Image |
| 7 | +from datasets import load_dataset |
| 8 | +from docling_core.types import DoclingDocument |
| 9 | +from docling_core.types.doc import ( |
| 10 | + BoundingBox, |
| 11 | + CoordOrigin, |
| 12 | + DocItemLabel, |
| 13 | + GroupItem, |
| 14 | + GroupLabel, |
| 15 | + ImageRef, |
| 16 | + PageItem, |
| 17 | + ProvenanceItem, |
| 18 | + Size, |
| 19 | + TableCell, |
| 20 | + TableData, |
| 21 | +) |
| 22 | +from docling_core.types.io import DocumentStream |
| 23 | +from tqdm import tqdm |
| 24 | + |
| 25 | +from docling_eval.datamodels.dataset_record import DatasetRecord |
| 26 | +from docling_eval.datamodels.types import BenchMarkColumns, EvaluationModality |
| 27 | +from docling_eval.dataset_builders.dataset_builder import ( |
| 28 | + BaseEvaluationDatasetBuilder, |
| 29 | + HFSource, |
| 30 | +) |
| 31 | +from docling_eval.utils.utils import ( |
| 32 | + add_pages_to_true_doc, |
| 33 | + crop_bounding_box, |
| 34 | + extract_images, |
| 35 | + from_pil_to_base64uri, |
| 36 | + get_binhash, |
| 37 | +) |
| 38 | + |
# Module-level logger, named after this module per the standard logging convention.
_log = logging.getLogger(__name__)
| 41 | + |
| 42 | + |
class DocVQADatasetBuilder(BaseEvaluationDatasetBuilder):
    """
    DocVQA dataset builder implementing the base dataset builder interface.

    This builder processes the DocVQA dataset, which contains document
    layout annotations for a variety of document types. Dataset rows that
    share the same ``docId`` are grouped into a single one-page
    ground-truth document.
    """

    def __init__(
        self,
        target: Path,
        split: str = "test",
        begin_index: int = 0,
        end_index: int = -1,
    ):
        """
        Initialize the DocVQA dataset builder.

        Args:
            target: Path where processed dataset will be saved
            split: Dataset split to use
            begin_index: Start index for processing (inclusive)
            end_index: End index for processing (exclusive), -1 means process all
        """
        super().__init__(
            name="DocVQA",
            dataset_source=HFSource(repo_id="lmms-lab/DocVQA"),
            target=target,
            split=split,
            begin_index=begin_index,
            end_index=end_index,
        )

    def _process_document(self, doc_id, qa_items) -> DatasetRecord:
        """
        Build one ground-truth DatasetRecord from all QA items of a document.

        Args:
            doc_id: The DocVQA ``docId`` shared by every row in ``qa_items``.
            qa_items: Non-empty list of dataset rows for this document.
                All rows of one document are assumed to carry the same page
                image; the image of the first row is used.

        Returns:
            A DatasetRecord holding the single-page ground-truth document.

        Raises:
            ValueError: If ``qa_items`` is empty.
        """
        # Guard: qa_items[0] below would otherwise raise a bare IndexError.
        if not qa_items:
            raise ValueError(f"No QA items supplied for document {doc_id!r}")

        # Lazy %-style args keep this consistent with the logging style used
        # in iterate() and avoid formatting when DEBUG is disabled.
        _log.debug("Processing document %s with %d QA item(s)", doc_id, len(qa_items))

        doc = DoclingDocument(name=f"{doc_id}")

        image: PIL.Image.Image = qa_items[0]["image"].convert("RGB")
        image_ref = ImageRef(
            mimetype="image/png",
            dpi=72,
            size=Size(width=image.width, height=image.height),
            uri=from_pil_to_base64uri(image),
        )
        doc.pages[1] = PageItem(
            page_no=1,
            size=Size(width=float(image.width), height=float(image.height)),
            image=image_ref,
        )

        # Extract images from the ground-truth document into the dedicated
        # dataset columns (pictures / page images).
        doc, true_pictures, true_page_images = extract_images(
            document=doc,
            pictures_column=BenchMarkColumns.GROUNDTRUTH_PICTURES.value,
            page_images_column=BenchMarkColumns.GROUNDTRUTH_PAGE_IMAGES.value,
        )

        # Serialize the page image once; the bytes serve both as the
        # "original" document stream and as input for the content hash.
        # (getvalue() is position-independent, so no seek is needed.)
        with io.BytesIO() as img_byte_stream:
            image.save(img_byte_stream, format="PNG")
            img_bytes = img_byte_stream.getvalue()

        return DatasetRecord(
            doc_id=str(doc_id),
            doc_hash=get_binhash(img_bytes),
            ground_truth_doc=doc,
            original=DocumentStream(name=str(doc_id), stream=io.BytesIO(img_bytes)),
            mime_type="image/png",
            modalities=[
                EvaluationModality.LAYOUT,
                EvaluationModality.QUESTION_ANSWERING,
            ],
            ground_truth_pictures=true_pictures,
            ground_truth_page_images=true_page_images,
        )

    def iterate(self) -> Iterable[DatasetRecord]:
        """
        Iterate through the dataset and yield one DatasetRecord per document.

        Rows are sorted by ``docId`` so that all QA items belonging to the
        same document appear consecutively and can be grouped in a single
        linear pass.

        Yields:
            DatasetRecord objects
        """
        assert isinstance(self.dataset_source, HFSource)

        # Prefer a locally cached copy of the dataset when one is available.
        path = self.dataset_source.repo_id
        if self.dataset_local_path is not None:
            path = str(self.dataset_local_path)
        ds = load_dataset(path, split=self.split, name="DocVQA")

        # Restrict to the configured [begin, end) index window using
        # HuggingFace's select method.
        total_ds_len = len(ds)
        begin, end = self.get_effective_indices(total_ds_len)
        ds = ds.select(range(begin, end))
        selected_ds_len = len(ds)

        self.log_dataset_stats(total_ds_len, selected_ds_len)

        # skipped_rows is kept for log symmetry; this builder never skips rows.
        skipped_rows = 0
        exported_rows = 0

        # Sorting by docId guarantees all rows of a document are contiguous.
        sorted_dataset = ds.sort("docId")

        current_doc_id = None
        current_doc_qa_items: List = []

        for sample in tqdm(
            sorted_dataset,
            total=selected_ds_len,
            ncols=128,
            desc="Processing DocVQA records...",
        ):
            if sample["docId"] != current_doc_id:
                # Flush the previous document's group (no-op on the first row).
                if current_doc_qa_items:
                    yield self._process_document(current_doc_id, current_doc_qa_items)
                    exported_rows += 1
                # Start a new document group.
                current_doc_id = sample["docId"]
                current_doc_qa_items = [sample]
            else:
                current_doc_qa_items.append(sample)

        # Flush the final document group.
        if current_doc_qa_items:
            yield self._process_document(current_doc_id, current_doc_qa_items)
            exported_rows += 1

        _log.info(
            "Exported rows: %s. Skipped rows: %s.",
            exported_rows,
            skipped_rows,
        )
0 commit comments