diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 11448437a..3586b1b31 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -283,6 +283,9 @@ class PipelineOptions(BaseOptions): class ConvertPipelineOptions(PipelineOptions): """Base convert pipeline options.""" + do_ocr: bool = True # True: perform OCR, replace programmatic PDF text + ocr_options: OcrOptions = OcrAutoOptions() + do_picture_classification: bool = False # True: classify pictures in documents do_picture_description: bool = False # True: run describe pictures in documents @@ -335,7 +338,6 @@ class PdfPipelineOptions(PaginatedPipelineOptions): """Options for the PDF pipeline.""" do_table_structure: bool = True # True: perform table structure extraction - do_ocr: bool = True # True: perform OCR, replace programmatic PDF text do_code_enrichment: bool = False # True: perform code OCR do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code force_backend_text: bool = ( @@ -344,7 +346,6 @@ class PdfPipelineOptions(PaginatedPipelineOptions): # If True, text from backend will be used instead of generated text table_structure_options: TableStructureOptions = TableStructureOptions() - ocr_options: OcrOptions = OcrAutoOptions() layout_options: LayoutOptions = LayoutOptions() images_scale: float = 1.0 diff --git a/docling/models/ocr_enrichment.py b/docling/models/ocr_enrichment.py new file mode 100644 index 000000000..9e7855885 --- /dev/null +++ b/docling/models/ocr_enrichment.py @@ -0,0 +1,65 @@ +from collections.abc import Iterable +from pathlib import Path +from typing import List, Optional, Type, Union + +from docling_core.types.doc import ( + DoclingDocument, + NodeItem, + PictureItem, +) +from PIL import Image + +from docling.datamodel.accelerator_options import AcceleratorOptions +from docling.datamodel.pipeline_options import ( + OcrOptions, +) +from docling.models.base_model import ( + BaseItemAndImageEnrichmentModel, + ItemAndImageEnrichmentElement, +) +from docling.models.base_ocr_model import BaseOcrModel +from docling.models.factories import get_ocr_factory + + +class OcrEnrichmentModel(BaseItemAndImageEnrichmentModel): + images_scale: float = 2.0 + + def __init__( + self, + *, + enabled: bool, + artifacts_path: Optional[Union[Path, str]], + options: OcrOptions, + accelerator_options: AcceleratorOptions, + allow_external_plugins: bool, + ): + self.enabled = enabled + self.options = options + + self._ocr_model: BaseOcrModel + + if self.enabled: + ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins) + self._ocr_model = ocr_factory.create_instance( + options=self.options, + enabled=True, + artifacts_path=artifacts_path, + accelerator_options=accelerator_options, + ) + + def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool: + return self.enabled and isinstance(element, PictureItem) + + def __call__( + self, + doc: DoclingDocument, + element_batch: Iterable[ItemAndImageEnrichmentElement], + ) -> Iterable[NodeItem]: + if not self.enabled: + for element in element_batch: + yield element.item + return + + # TODO: call self._ocr_model + for element in element_batch: + yield element.item diff --git a/docling/pipeline/simple_pipeline.py b/docling/pipeline/simple_pipeline.py index 0e3f1b6f9..98b139c33 100644 --- a/docling/pipeline/simple_pipeline.py +++ b/docling/pipeline/simple_pipeline.py @@ -7,6 +7,7 @@ from docling.datamodel.base_models import ConversionStatus from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ConvertPipelineOptions +from docling.models.ocr_enrichment import OcrEnrichmentModel from docling.pipeline.base_pipeline import ConvertPipeline from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -23,6 +24,17 @@ class SimplePipeline(ConvertPipeline): def __init__(self, pipeline_options: ConvertPipelineOptions): super().__init__(pipeline_options) + self.enrichment_pipe.insert( + 0, + OcrEnrichmentModel( + enabled=self.pipeline_options.do_ocr, + options=self.pipeline_options.ocr_options, + allow_external_plugins=self.pipeline_options.allow_external_plugins, + artifacts_path=self.pipeline_options.artifacts_path, + accelerator_options=self.pipeline_options.accelerator_options, + ), + ) + def _build_document(self, conv_res: ConversionResult) -> ConversionResult: if not isinstance(conv_res.input._backend, DeclarativeDocumentBackend): raise RuntimeError(