
Commit c08950b

perf: Improve parquet writing with plain pyarrow (#134)
* perf: Improve parquet writing with plain pyarrow
* Smaller fixes
* Add pyarrow dep
* Fix circular import

Signed-off-by: Christoph Auer <[email protected]>
1 parent a34f264 commit c08950b

File tree

10 files changed, +189 -510 lines changed

docling_eval/cli/main.py

Lines changed: 7 additions & 0 deletions
@@ -910,13 +910,20 @@ def create_cvat(
     gt_dir: Annotated[Path, typer.Option(help="Dataset source path")],
     bucket_size: Annotated[int, typer.Option(help="Size of CVAT tasks")] = 20,
     use_predictions: Annotated[bool, typer.Option(help="use predictions")] = False,
+    sliding_window: Annotated[
+        int,
+        typer.Option(
+            help="Size of sliding window for page processing (1 for single pages, >1 for multi-page windows)"
+        ),
+    ] = 2,
 ):
     """Create dataset ready to upload to CVAT starting from (ground-truth) dataset."""
     builder = CvatPreannotationBuilder(
         dataset_source=gt_dir,
         target=output_dir,
         bucket_size=bucket_size,
         use_predictions=use_predictions,
+        sliding_window=sliding_window,
     )
     builder.prepare_for_annotation()

docling_eval/datamodels/dataset_record.py

Lines changed: 101 additions & 47 deletions
@@ -1,4 +1,5 @@
 import json
+from enum import Enum
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -19,6 +20,61 @@
 seg_adapter = TypeAdapter(Dict[int, SegmentedPage])


+class FieldType(Enum):
+    STRING = "string"
+    BINARY = "binary"
+    IMAGE_LIST = "image_list"
+    STRING_LIST = "string_list"
+
+
+class SchemaGenerator:
+    """Generates both HuggingFace Features and PyArrow schemas from a field definition."""
+
+    @staticmethod
+    def _get_features_type(field_type: FieldType):
+        mapping = {
+            FieldType.STRING: Value("string"),
+            FieldType.BINARY: Value("binary"),
+            FieldType.IMAGE_LIST: Sequence(Features_Image()),
+            FieldType.STRING_LIST: Sequence(Value("string")),
+        }
+        return mapping[field_type]
+
+    @staticmethod
+    def _get_pyarrow_type(field_type: FieldType):
+        import pyarrow as pa
+
+        image_type = pa.struct([("bytes", pa.binary()), ("path", pa.string())])
+
+        mapping = {
+            FieldType.STRING: pa.string(),
+            FieldType.BINARY: pa.binary(),
+            FieldType.IMAGE_LIST: pa.list_(image_type),
+            FieldType.STRING_LIST: pa.list_(pa.string()),
+        }
+        return mapping[field_type]
+
+    @classmethod
+    def generate_features(cls, field_definitions: Dict[str, FieldType]) -> Features:
+        return Features(
+            {
+                field_name: cls._get_features_type(field_type)
+                for field_name, field_type in field_definitions.items()
+            }
+        )
+
+    @classmethod
+    def generate_pyarrow_schema(cls, field_definitions: Dict[str, FieldType]):
+        import pyarrow as pa
+
+        return pa.schema(
+            [
+                (field_name, cls._get_pyarrow_type(field_type))
+                for field_name, field_type in field_definitions.items()
+            ]
+        )
+
+
 class DatasetRecord(
     BaseModel
 ):  # TODO make predictionrecord class, factor prediction-related fields there.
@@ -51,26 +107,30 @@ class DatasetRecord(
     def get_field_alias(cls, field_name: str) -> str:
         return cls.model_fields[field_name].alias or field_name

+    @classmethod
+    def _get_field_definitions(cls) -> Dict[str, FieldType]:
+        """Define the schema for this class. Override in subclasses to extend."""
+        return {
+            cls.get_field_alias("doc_id"): FieldType.STRING,
+            cls.get_field_alias("doc_path"): FieldType.STRING,
+            cls.get_field_alias("doc_hash"): FieldType.STRING,
+            cls.get_field_alias("ground_truth_doc"): FieldType.STRING,
+            cls.get_field_alias("ground_truth_segmented_pages"): FieldType.STRING,
+            cls.get_field_alias("ground_truth_pictures"): FieldType.IMAGE_LIST,
+            cls.get_field_alias("ground_truth_page_images"): FieldType.IMAGE_LIST,
+            cls.get_field_alias("original"): FieldType.BINARY,
+            cls.get_field_alias("mime_type"): FieldType.STRING,
+            cls.get_field_alias("modalities"): FieldType.STRING_LIST,
+        }
+
     @classmethod
     def features(cls):
-        return Features(
-            {
-                cls.get_field_alias("doc_id"): Value("string"),
-                cls.get_field_alias("doc_path"): Value("string"),
-                cls.get_field_alias("doc_hash"): Value("string"),
-                cls.get_field_alias("ground_truth_doc"): Value("string"),
-                cls.get_field_alias("ground_truth_segmented_pages"): Value("string"),
-                cls.get_field_alias("ground_truth_pictures"): Sequence(
-                    Features_Image()
-                ),
-                cls.get_field_alias("ground_truth_page_images"): Sequence(
-                    Features_Image()
-                ),
-                cls.get_field_alias("original"): Value("binary"),
-                cls.get_field_alias("mime_type"): Value("string"),
-                cls.get_field_alias("modalities"): Sequence(Value("string")),
-            }
-        )
+        return SchemaGenerator.generate_features(cls._get_field_definitions())
+
+    @classmethod
+    def pyarrow_schema(cls):
+        """Generate PyArrow schema that matches the HuggingFace datasets image format."""
+        return SchemaGenerator.generate_pyarrow_schema(cls._get_field_definitions())

     def _extract_images(
         self,
@@ -207,37 +267,31 @@ class DatasetRecordWithPrediction(DatasetRecord):

     model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

+    @classmethod
+    def _get_field_definitions(cls) -> Dict[str, FieldType]:
+        """Extend the parent schema with prediction-specific fields."""
+        base_definitions = super()._get_field_definitions()
+        prediction_definitions = {
+            cls.get_field_alias("predictor_info"): FieldType.STRING,
+            cls.get_field_alias("status"): FieldType.STRING,
+            cls.get_field_alias("predicted_doc"): FieldType.STRING,
+            cls.get_field_alias("predicted_segmented_pages"): FieldType.STRING,
+            cls.get_field_alias("predicted_pictures"): FieldType.IMAGE_LIST,
+            cls.get_field_alias("predicted_page_images"): FieldType.IMAGE_LIST,
+            cls.get_field_alias("prediction_format"): FieldType.STRING,
+            cls.get_field_alias("prediction_timings"): FieldType.STRING,
+            cls.get_field_alias("original_prediction"): FieldType.STRING,
+        }
+        return {**base_definitions, **prediction_definitions}
+
     @classmethod
     def features(cls):
-        return Features(
-            {
-                cls.get_field_alias("doc_id"): Value("string"),
-                cls.get_field_alias("doc_path"): Value("string"),
-                cls.get_field_alias("doc_hash"): Value("string"),
-                cls.get_field_alias("ground_truth_doc"): Value("string"),
-                cls.get_field_alias("ground_truth_segmented_pages"): Value("string"),
-                cls.get_field_alias("ground_truth_pictures"): Sequence(
-                    Features_Image()
-                ),
-                cls.get_field_alias("ground_truth_page_images"): Sequence(
-                    Features_Image()
-                ),
-                cls.get_field_alias("original"): Value("binary"),
-                cls.get_field_alias("mime_type"): Value("string"),
-                cls.get_field_alias("modalities"): Sequence(Value("string")),
-                cls.get_field_alias("predictor_info"): Value("string"),
-                cls.get_field_alias("status"): Value("string"),
-                cls.get_field_alias("predicted_doc"): Value("string"),
-                cls.get_field_alias("predicted_segmented_pages"): Value("string"),
-                cls.get_field_alias("predicted_pictures"): Sequence(Features_Image()),
-                cls.get_field_alias("predicted_page_images"): Sequence(
-                    Features_Image()
-                ),
-                cls.get_field_alias("prediction_format"): Value("string"),
-                cls.get_field_alias("prediction_timings"): Value("string"),
-                cls.get_field_alias("original_prediction"): Value("string"),
-            }
-        )
+        return SchemaGenerator.generate_features(cls._get_field_definitions())
+
+    @classmethod
+    def pyarrow_schema(cls):
+        """Generate PyArrow schema that matches the HuggingFace datasets image format."""
+        return SchemaGenerator.generate_pyarrow_schema(cls._get_field_definitions())

     def as_record_dict(self):
         record = super().as_record_dict()

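Both record classes now expose the two schema views generated from the same field definitions. A minimal inspection sketch (assumes docling_eval and pyarrow are installed; the printed field names depend on the model aliases):

# Minimal sketch: the Features and pyarrow views come from the same
# field definitions, so their field names and nesting match one another.
from docling_eval.datamodels.dataset_record import (
    DatasetRecord,
    DatasetRecordWithPrediction,
)

# HuggingFace `datasets` Features, used when loading shards back as a Dataset.
print(DatasetRecord.features())

# Plain pyarrow schema with the same layout, used for direct parquet writes.
print(DatasetRecord.pyarrow_schema())

# The prediction record extends the same definitions with prediction fields.
print(DatasetRecordWithPrediction.pyarrow_schema())
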
docling_eval/dataset_builders/cvat_preannotation_builder.py

Lines changed: 5 additions & 1 deletion
@@ -47,6 +47,7 @@ def __init__(
         target: Path,
         bucket_size: int = 200,
         use_predictions: bool = False,
+        sliding_window: int = 2,
     ):
         """
         Initialize the CvatPreannotationBuilder.
@@ -55,10 +56,13 @@
             dataset_source: Directory containing the source dataset
             target: Directory where CVAT preannotations will be saved
             bucket_size: Number of documents per bucket for CVAT tasks
+            use_predictions: Whether to use predictions instead of ground truth
+            sliding_window: Size of sliding window for page processing (1 for single pages, >1 for multi-page windows)
         """
         self.source_dir = dataset_source
         self.target_dir = target
         self.bucket_size = bucket_size
+        self.sliding_window = sliding_window
         self.benchmark_dirs = BenchMarkDirs()
         self.benchmark_dirs.set_up_directory_structure(
             source=dataset_source, target=target
@@ -799,7 +803,7 @@ def prepare_for_annotation(self) -> None:
         _log.info(f"Preparing dataset from {self.source_dir} for CVAT annotation")
         self._create_project_properties()
         self.overview = self._export_from_dataset()
-        self._create_preannotation_files(sliding_window=1)
+        self._create_preannotation_files(sliding_window=self.sliding_window)
         _log.info(f"CVAT annotation preparation complete in {self.target_dir}")

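A hedged usage sketch of the new knob (the paths are placeholders; the other arguments mirror the constructor shown above):

from pathlib import Path

from docling_eval.dataset_builders.cvat_preannotation_builder import (
    CvatPreannotationBuilder,
)

builder = CvatPreannotationBuilder(
    dataset_source=Path("./gt_dataset"),  # placeholder path
    target=Path("./cvat_tasks"),          # placeholder path
    bucket_size=20,
    use_predictions=False,
    sliding_window=2,  # 1 = single pages, >1 = multi-page windows
)
builder.prepare_for_annotation()
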
docling_eval/dataset_builders/dataset_builder.py

Lines changed: 1 addition & 1 deletion
@@ -305,8 +305,8 @@ def save_to_disk(
         save_shard_to_disk(
             items=record_list,
             dataset_path=test_dir,
+            schema=DatasetRecord.pyarrow_schema(),
             shard_id=chunk_count,
-            features=DatasetRecord.features(),
         )
         count += len(record_list)
         chunk_count += 1

docling_eval/prediction_providers/base_prediction_provider.py

Lines changed: 1 addition & 1 deletion
@@ -398,8 +398,8 @@ def _iterate_predictions() -> Iterable[DatasetRecordWithPrediction]:
         save_shard_to_disk(
             items=record_chunk,
             dataset_path=test_dir,
+            schema=DatasetRecordWithPrediction.pyarrow_schema(),
             shard_id=chunk_count,
-            features=DatasetRecordWithPrediction.features(),
         )
         count += len(record_chunk)
         chunk_count += 1

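Both call sites above follow the same pattern: the precomputed Arrow schema replaces the former features= argument. A self-contained sketch of the call, with placeholder values standing in for the variables used in the surrounding code:

from pathlib import Path

from docling_eval.datamodels.dataset_record import DatasetRecordWithPrediction
from docling_eval.utils.utils import save_shard_to_disk

record_chunk: list = []    # record dicts, e.g. from DatasetRecordWithPrediction.as_record_dict()
test_dir = Path("./test")  # placeholder shard output directory

save_shard_to_disk(
    items=record_chunk,
    dataset_path=test_dir,
    schema=DatasetRecordWithPrediction.pyarrow_schema(),  # explicit Arrow schema, now mandatory
    shard_id=0,
)
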
docling_eval/utils/utils.py

Lines changed: 70 additions & 19 deletions
@@ -7,14 +7,15 @@
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import pandas as pd
 import PIL.Image
 from bs4 import BeautifulSoup  # type: ignore
 from datasets import Dataset, Features, load_dataset
 from datasets.iterable_dataset import IterableDataset
 from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
+from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import InputDocument
 from docling_core.types.doc.base import BoundingBox, CoordOrigin, Size
@@ -78,12 +79,14 @@ def write_datasets_info(
         json.dump(dataset_infos, fw, indent=2)


-def get_input_document(file: Path | BytesIO) -> InputDocument:
+def get_input_document(
+    file: Path | BytesIO, backend_t: Type[Any] = DoclingParseV4DocumentBackend
+) -> InputDocument:
     return InputDocument(
         path_or_stream=file,
         format=InputFormat.PDF,  # type: ignore[arg-type]
         filename=file.name if isinstance(file, Path) else "foo",
-        backend=DoclingParseV4DocumentBackend,
+        backend=backend_t,
     )

@@ -97,7 +100,7 @@ def from_pil_to_base64uri(img: Image.Image) -> AnyUrl:
 def add_pages_to_true_doc(
     pdf_path: Path | BytesIO, true_doc: DoclingDocument, image_scale: float = 1.0
 ):
-    in_doc = get_input_document(pdf_path)
+    in_doc = get_input_document(pdf_path, backend_t=PyPdfiumDocumentBackend)
     assert in_doc.valid, "Input doc must be valid."
     # assert in_doc.page_count == 1, "doc must have one page."

@@ -106,7 +109,11 @@

     for page_no in range(0, in_doc.page_count):
         page = Page(page_no=page_no)
-        page._backend = in_doc._backend.load_page(page.page_no)  # type: ignore[attr-defined]
+        try:
+            page._backend = in_doc._backend.load_page(page.page_no)  # type: ignore[attr-defined]
+        except RuntimeError as e:
+            logging.warning(f"Failed to load page {page.page_no}: {e}")
+            page._backend = None

         if page._backend is not None and page._backend.is_valid():
             page.size = page._backend.get_size()
@@ -489,34 +496,78 @@ def insert_images(
     return document


+def _pil_to_bytes(img: PIL.Image.Image) -> bytes:
+    """Convert PIL image to PNG bytes efficiently."""
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    return buffered.getvalue()
+
+
 def save_shard_to_disk(
     items: List[Any],
     dataset_path: Path,
+    schema: Any,
     thread_id: int = 0,
     shard_id: int = 0,
-    features: Optional[Features] = None,
-    shard_format: str = "parquet",
 ) -> None:
-    """Save shard to disk."""
+    """Save shard to disk as parquet."""
     if not items:
         return

-    # Use features if provided to avoid schema inference
-    batch = Dataset.from_list(items, features=features)
-
-    output_file = dataset_path / f"shard_{thread_id:06}_{shard_id:06}.{shard_format}"
-    if shard_format == "json":
-        batch.to_json(output_file)
-    elif shard_format == "parquet":
-        batch.to_parquet(output_file)
-    else:
-        raise ValueError(f"Unsupported shard_format: {shard_format}")
+    # Write directly to parquet using pyarrow to avoid Dataset.from_list() overhead
+    _save_to_parquet_direct(items, dataset_path, thread_id, shard_id, schema)

-    logging.info(f"Saved shard {shard_id} to {output_file} with {len(items)} documents")
+    logging.info(
+        f"Saved shard {shard_id} to {dataset_path / f'shard_{thread_id:06}_{shard_id:06}.parquet'} with {len(items)} documents"
+    )

     shard_id += 1


+def _save_to_parquet_direct(
+    items: List[Any], dataset_path: Path, thread_id: int, shard_id: int, schema: Any
+) -> None:
+    """Save directly to parquet using pyarrow to avoid Dataset.from_list() overhead."""
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    # Import here to avoid circular import
+    from docling_eval.datamodels.dataset_record import DatasetRecordWithPrediction
+
+    # Convert data to pyarrow table format
+    records = []
+    for item in items:
+        record = dict(item)
+
+        # Convert PIL images to bytes for direct Arrow storage
+        for field_name in [
+            DatasetRecordWithPrediction.get_field_alias("ground_truth_pictures"),
+            DatasetRecordWithPrediction.get_field_alias("ground_truth_page_images"),
+            DatasetRecordWithPrediction.get_field_alias("predicted_pictures"),
+            DatasetRecordWithPrediction.get_field_alias("predicted_page_images"),
+        ]:
+            if field_name in record:
+                images = record[field_name]
+                if (
+                    images
+                    and len(images) > 0
+                    and isinstance(images[0], PIL.Image.Image)
+                ):
+                    # Convert to the same format as HuggingFace datasets expects
+                    record[field_name] = [
+                        {"bytes": _pil_to_bytes(img), "path": None} for img in images
+                    ]
+
+        records.append(record)
+
+    # Create pyarrow table with mandatory explicit schema
+    table = pa.Table.from_pylist(records, schema=schema)
+
+    # Write to parquet
+    output_file = dataset_path / f"shard_{thread_id:06}_{shard_id:06}.parquet"
+    pq.write_table(table, output_file)
+
+
 def dataset_exists(
     ds_path: Path,
     split: str,
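
The speedup comes from skipping Dataset.from_list(), which pushes every record through the datasets feature-encoding machinery before writing; building a pa.Table straight from the record dicts and calling pq.write_table avoids that pass, while storing each image as a {"bytes", "path"} struct keeps the file compatible with the HuggingFace Image feature. A self-contained toy sketch of the pattern (toy field names, not the real DatasetRecord schema):

# Toy sketch of the direct-write pattern used in _save_to_parquet_direct.
import io

import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq

# Same struct layout the datasets Image feature expects on disk.
image_type = pa.struct([("bytes", pa.binary()), ("path", pa.string())])
schema = pa.schema([("doc_id", pa.string()), ("page_images", pa.list_(image_type))])


def pil_to_bytes(img: PIL.Image.Image) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


img = PIL.Image.new("RGB", (32, 32), color="white")
records = [
    {
        "doc_id": "doc-0001",
        "page_images": [{"bytes": pil_to_bytes(img), "path": None}],
    }
]

# Explicit schema: no inference pass, no datasets encoding overhead.
table = pa.Table.from_pylist(records, schema=schema)
pq.write_table(table, "shard_000000_000000.parquet")

Because the image columns keep the {bytes, path} struct layout, shards written this way should still load back through the datasets library with the matching features().
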
