|
1 | 1 | import os |
2 | 2 | from pathlib import Path |
3 | | -from typing import List, Optional |
4 | 3 |
|
5 | 4 | import pytest |
6 | | -from docling.datamodel.base_models import InputFormat |
7 | | -from docling.datamodel.pipeline_options import ( |
8 | | - EasyOcrOptions, |
9 | | - OcrOptions, |
10 | | - PdfPipelineOptions, |
11 | | - TableFormerMode, |
12 | | -) |
13 | | -from docling.document_converter import PdfFormatOption |
14 | | -from docling.models.factories import get_ocr_factory |
15 | 5 |
|
16 | | -from docling_eval.cli.main import evaluate, visualize |
| 6 | +from docling_eval.cli.main import ( |
| 7 | + PredictionProviderType, |
| 8 | + evaluate, |
| 9 | + get_prediction_provider, |
| 10 | + visualize, |
| 11 | +) |
17 | 12 | from docling_eval.datamodels.types import ( |
18 | 13 | BenchMarkNames, |
19 | 14 | EvaluationModality, |
|
33 | 28 | PubTabNetDatasetBuilder, |
34 | 29 | ) |
35 | 30 | from docling_eval.dataset_builders.xfund_builder import XFUNDDatasetBuilder |
36 | | -from docling_eval.prediction_providers.docling_provider import DoclingPredictionProvider |
37 | 31 | from docling_eval.prediction_providers.file_provider import FilePredictionProvider |
38 | 32 | from docling_eval.prediction_providers.tableformer_provider import ( |
39 | 33 | TableFormerPredictionProvider, |
40 | 34 | ) |
41 | 35 |
|
42 | | -ocr_factory = get_ocr_factory() |
43 | | - |
44 | 36 | IS_CI = os.getenv("RUN_IN_CI") == "1" |
45 | 37 |
|
46 | 38 |
|
47 | | -def create_docling_prediction_provider( |
48 | | - page_image_scale: float = 2.0, |
49 | | - do_ocr: bool = False, |
50 | | - ocr_lang: Optional[List[str]] = None, |
51 | | - ocr_engine: str = EasyOcrOptions.kind, |
52 | | - artifacts_path: Optional[Path] = None, |
53 | | -): |
54 | | - ocr_options: OcrOptions = ocr_factory.create_options( # type: ignore |
55 | | - kind=ocr_engine, |
56 | | - ) |
57 | | - if ocr_lang is not None: |
58 | | - ocr_options.lang = ocr_lang |
59 | | - |
60 | | - pipeline_options = PdfPipelineOptions( |
61 | | - do_ocr=do_ocr, |
62 | | - ocr_options=ocr_options, |
63 | | - do_table_structure=True, |
64 | | - artifacts_path=artifacts_path, |
65 | | - ) |
66 | | - |
67 | | - pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE |
68 | | - |
69 | | - pipeline_options.images_scale = page_image_scale |
70 | | - pipeline_options.generate_page_images = True |
71 | | - pipeline_options.generate_picture_images = True |
72 | | - |
73 | | - return DoclingPredictionProvider( |
74 | | - format_options={ |
75 | | - InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options) |
76 | | - }, |
77 | | - do_visualization=True, |
78 | | - ) |
79 | | - |
80 | | - |
81 | 39 | @pytest.mark.dependency() |
82 | 40 | def test_run_dpbench_e2e(): |
83 | 41 | target_path = Path(f"./scratch/{BenchMarkNames.DPBENCH.value}/") |
84 | | - docling_provider = create_docling_prediction_provider(page_image_scale=2.0) |
| 42 | + docling_provider = get_prediction_provider(PredictionProviderType.DOCLING) |
85 | 43 |
|
86 | 44 | dataset_layout = DPBenchDatasetBuilder( |
87 | 45 | target=target_path / "gt_dataset", |
@@ -207,7 +165,7 @@ def test_run_doclaynet_with_doctags_fileprovider(): |
207 | 165 | ) |
208 | 166 | def test_run_omnidocbench_e2e(): |
209 | 167 | target_path = Path(f"./scratch/{BenchMarkNames.OMNIDOCBENCH.value}/") |
210 | | - docling_provider = create_docling_prediction_provider(page_image_scale=2.0) |
| 168 | + docling_provider = get_prediction_provider(PredictionProviderType.DOCLING) |
211 | 169 |
|
212 | 170 | dataset_layout = OmniDocBenchDatasetBuilder( |
213 | 171 | target=target_path / "gt_dataset", |
@@ -339,7 +297,7 @@ def test_run_omnidocbench_tables(): |
339 | 297 | ) |
340 | 298 | def test_run_doclaynet_v1_e2e(): |
341 | 299 | target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV1.value}/") |
342 | | - docling_provider = create_docling_prediction_provider(page_image_scale=2.0) |
| 300 | + docling_provider = get_prediction_provider(PredictionProviderType.DOCLING) |
343 | 301 |
|
344 | 302 | dataset_layout = DocLayNetV1DatasetBuilder( |
345 | 303 | # prediction_provider=docling_provider, |
@@ -390,7 +348,7 @@ def test_run_doclaynet_v1_e2e(): |
390 | 348 | @pytest.mark.skip("Test needs local data which is unavailable.") |
391 | 349 | def test_run_doclaynet_v2_e2e(): |
392 | 350 | target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV2.value}/") |
393 | | - docling_provider = create_docling_prediction_provider(page_image_scale=2.0) |
| 351 | + docling_provider = get_prediction_provider(PredictionProviderType.DOCLING) |
394 | 352 |
|
395 | 353 | dataset_layout = DocLayNetV2DatasetBuilder( |
396 | 354 | dataset_source=Path("/path/to/doclaynet_v2_benchmark"), |
@@ -594,7 +552,7 @@ def test_run_docvqa_builder(): |
594 | 552 | ) |
595 | 553 |
|
596 | 554 | dataset_layout.save_to_disk() # does all the job of iterating the dataset, making GT+prediction records, and saving them in shards as parquet. |
597 | | - docling_provider = create_docling_prediction_provider(page_image_scale=2.0) |
| 555 | + docling_provider = get_prediction_provider(PredictionProviderType.DOCLING) |
598 | 556 |
|
599 | 557 | docling_provider.create_prediction_dataset( |
600 | 558 | name=dataset_layout.name, |
|
0 commit comments