Skip to content

Commit 6efba9b

Browse files
authored
feat: Add SmolDocling provider and CLI option, cleanup test unit (#56)
Signed-off-by: Christoph Auer <[email protected]>
1 parent 09c40f1 commit 6efba9b

File tree

2 files changed: 54 additions & 56 deletions

2 files changed: 54 additions & 56 deletions

docling_eval/cli/main.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import json
22
import logging
33
import os
4+
import sys
45
from enum import Enum
56
from pathlib import Path
6-
from typing import Annotated, Optional, Tuple
7+
from typing import Annotated, Dict, Optional, Tuple
78

89
import typer
910
from docling.datamodel.base_models import InputFormat
10-
from docling.datamodel.pipeline_options import PdfPipelineOptions
11-
from docling.document_converter import PdfFormatOption
11+
from docling.datamodel.pipeline_options import (
12+
PaginatedPipelineOptions,
13+
PdfPipelineOptions,
14+
VlmPipelineOptions,
15+
smoldocling_vlm_conversion_options,
16+
smoldocling_vlm_mlx_conversion_options,
17+
)
18+
from docling.document_converter import FormatOption, PdfFormatOption
1219
from docling.models.factories import get_ocr_factory
20+
from docling.pipeline.vlm_pipeline import VlmPipeline
1321
from tabulate import tabulate # type: ignore
1422

1523
from docling_eval.datamodels.types import (
@@ -74,6 +82,7 @@ class PredictionProviderType(str, Enum):
7482
DOCLING = "docling"
7583
TABLEFORMER = "tableformer"
7684
FILE = "file"
85+
SMOLDOCLING = "smoldocling"
7786

7887

7988
def log_and_save_stats(
@@ -184,6 +193,7 @@ def get_prediction_provider(
184193
file_source_path: Optional[Path] = None,
185194
file_prediction_format: Optional[PredictionFormats] = None,
186195
):
196+
pipeline_options: PaginatedPipelineOptions
187197
"""Get the appropriate prediction provider with default settings."""
188198
if provider_type == PredictionProviderType.DOCLING:
189199
ocr_factory = get_ocr_factory()
@@ -211,6 +221,36 @@ def get_prediction_provider(
211221
ignore_missing_predictions=True,
212222
)
213223

224+
elif provider_type == PredictionProviderType.SMOLDOCLING:
225+
pipeline_options = VlmPipelineOptions()
226+
227+
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
228+
if sys.platform == "darwin":
229+
try:
230+
import mlx_vlm # type: ignore
231+
232+
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
233+
except ImportError:
234+
_log.warning(
235+
"To run SmolDocling faster, please install mlx-vlm:\n"
236+
"pip install mlx-vlm"
237+
)
238+
239+
pdf_format_option = PdfFormatOption(
240+
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
241+
)
242+
243+
format_options: Dict[InputFormat, FormatOption] = {
244+
InputFormat.PDF: pdf_format_option,
245+
InputFormat.IMAGE: pdf_format_option,
246+
}
247+
248+
return DoclingPredictionProvider(
249+
format_options=format_options,
250+
do_visualization=True,
251+
ignore_missing_predictions=True,
252+
)
253+
214254
elif provider_type == PredictionProviderType.TABLEFORMER:
215255
return TableFormerPredictionProvider(
216256
do_visualization=True,

tests/test_dataset_builder.py

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
import os
22
from pathlib import Path
3-
from typing import List, Optional
43

54
import pytest
6-
from docling.datamodel.base_models import InputFormat
7-
from docling.datamodel.pipeline_options import (
8-
EasyOcrOptions,
9-
OcrOptions,
10-
PdfPipelineOptions,
11-
TableFormerMode,
12-
)
13-
from docling.document_converter import PdfFormatOption
14-
from docling.models.factories import get_ocr_factory
155

16-
from docling_eval.cli.main import evaluate, visualize
6+
from docling_eval.cli.main import (
7+
PredictionProviderType,
8+
evaluate,
9+
get_prediction_provider,
10+
visualize,
11+
)
1712
from docling_eval.datamodels.types import (
1813
BenchMarkNames,
1914
EvaluationModality,
@@ -33,55 +28,18 @@
3328
PubTabNetDatasetBuilder,
3429
)
3530
from docling_eval.dataset_builders.xfund_builder import XFUNDDatasetBuilder
36-
from docling_eval.prediction_providers.docling_provider import DoclingPredictionProvider
3731
from docling_eval.prediction_providers.file_provider import FilePredictionProvider
3832
from docling_eval.prediction_providers.tableformer_provider import (
3933
TableFormerPredictionProvider,
4034
)
4135

42-
ocr_factory = get_ocr_factory()
43-
4436
IS_CI = os.getenv("RUN_IN_CI") == "1"
4537

4638

47-
def create_docling_prediction_provider(
48-
page_image_scale: float = 2.0,
49-
do_ocr: bool = False,
50-
ocr_lang: Optional[List[str]] = None,
51-
ocr_engine: str = EasyOcrOptions.kind,
52-
artifacts_path: Optional[Path] = None,
53-
):
54-
ocr_options: OcrOptions = ocr_factory.create_options( # type: ignore
55-
kind=ocr_engine,
56-
)
57-
if ocr_lang is not None:
58-
ocr_options.lang = ocr_lang
59-
60-
pipeline_options = PdfPipelineOptions(
61-
do_ocr=do_ocr,
62-
ocr_options=ocr_options,
63-
do_table_structure=True,
64-
artifacts_path=artifacts_path,
65-
)
66-
67-
pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE
68-
69-
pipeline_options.images_scale = page_image_scale
70-
pipeline_options.generate_page_images = True
71-
pipeline_options.generate_picture_images = True
72-
73-
return DoclingPredictionProvider(
74-
format_options={
75-
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
76-
},
77-
do_visualization=True,
78-
)
79-
80-
8139
@pytest.mark.dependency()
8240
def test_run_dpbench_e2e():
8341
target_path = Path(f"./scratch/{BenchMarkNames.DPBENCH.value}/")
84-
docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
42+
docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
8543

8644
dataset_layout = DPBenchDatasetBuilder(
8745
target=target_path / "gt_dataset",
@@ -207,7 +165,7 @@ def test_run_doclaynet_with_doctags_fileprovider():
207165
)
208166
def test_run_omnidocbench_e2e():
209167
target_path = Path(f"./scratch/{BenchMarkNames.OMNIDOCBENCH.value}/")
210-
docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
168+
docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
211169

212170
dataset_layout = OmniDocBenchDatasetBuilder(
213171
target=target_path / "gt_dataset",
@@ -339,7 +297,7 @@ def test_run_omnidocbench_tables():
339297
)
340298
def test_run_doclaynet_v1_e2e():
341299
target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV1.value}/")
342-
docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
300+
docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
343301

344302
dataset_layout = DocLayNetV1DatasetBuilder(
345303
# prediction_provider=docling_provider,
@@ -390,7 +348,7 @@ def test_run_doclaynet_v1_e2e():
390348
@pytest.mark.skip("Test needs local data which is unavailable.")
391349
def test_run_doclaynet_v2_e2e():
392350
target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV2.value}/")
393-
docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
351+
docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
394352

395353
dataset_layout = DocLayNetV2DatasetBuilder(
396354
dataset_source=Path("/path/to/doclaynet_v2_benchmark"),
@@ -594,7 +552,7 @@ def test_run_docvqa_builder():
594552
)
595553

596554
dataset_layout.save_to_disk() # does all the job of iterating the dataset, making GT+prediction records, and saving them in shards as parquet.
597-
docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
555+
docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
598556

599557
docling_provider.create_prediction_dataset(
600558
name=dataset_layout.name,

0 commit comments

Comments
 (0)