diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index b403da25b5..f327670c33 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -20,7 +20,7 @@ env:
     tests/test_asr_pipeline.py
     tests/test_threaded_pipeline.py
   PYTEST_TO_SKIP: |-
-  EXAMPLES_TO_SKIP: '^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model|granitedocling_repetition_stopping)\.py$'
+  EXAMPLES_TO_SKIP: '^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|minimal_asr_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model|granitedocling_repetition_stopping|post_process_ocr_with_vlm)\.py$'
 
 jobs:
   lint:
diff --git a/docs/examples/post_process_ocr_with_vlm.py b/docs/examples/post_process_ocr_with_vlm.py
new file mode 100644
index 0000000000..04500c497b
--- /dev/null
+++ b/docs/examples/post_process_ocr_with_vlm.py
@@ -0,0 +1,512 @@
+import argparse
+import logging
+import os
+from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    ImageRefMode,
+    NodeItem,
+    TextItem,
+)
+from docling_core.types.doc.document import (
+    ContentLayer,
+    DocItem,
+    GraphCell,
+    KeyValueItem,
+    PictureItem,
+    RichTableCell,
+    TableCell,
+    TableItem,
+)
+from PIL import Image
+from PIL.ImageOps import crop
+from pydantic import BaseModel, ConfigDict
+from tqdm import tqdm
+
+from docling.backend.json.docling_json_backend import DoclingJSONBackend
+from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.base_models import InputFormat, ItemAndImageEnrichmentElement
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    ConvertPipelineOptions,
+    PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+)
+from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
+from docling.exceptions import OperationNotAllowed
+from docling.models.base_model import BaseModelWithOptions, GenericEnrichmentModel
+from docling.pipeline.simple_pipeline import SimplePipeline
+from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
+from docling.utils.api_image_request import api_image_request
+from docling.utils.profiling import ProfilingScope, TimeRecorder
+from docling.utils.utils import chunkify
+
+# Example of applying OCR to a DoclingDocument as a post-processing step,
+# using the "nanonets-ocr2-3b" model served via LM Studio.
+# Requires a running LM Studio inference server with "nanonets-ocr2-3b" pre-loaded.
+# To run:
+# uv run python docs/examples/post_process_ocr_with_vlm.py
+
+LM_STUDIO_URL = "http://localhost:1234/v1/chat/completions"
+LM_STUDIO_MODEL = "nanonets-ocr2-3b"
+
+DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally. Output pure text, no html and no markdown. Pay attention on line breaks and don't miss text after line break. Put all text in one line."
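+
+# Optional pre-flight check (illustrative sketch, not part of the pipeline
+# itself; assumes LM Studio exposes its OpenAI-compatible /v1/models endpoint
+# on the same host/port as above):
+#
+#   import requests
+#   requests.get("http://localhost:1234/v1/models", timeout=5).raise_for_status()
+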
+VERBOSE = False
+SHOW_IMAGE = False
+
+
+class PostOcrEnrichmentElement(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    item: Union[DocItem, TableCell, RichTableCell, GraphCell]
+    image: list[
+        Image.Image
+    ]  # Needs to be a list of images for multi-provenance elements
+
+
+class PostOcrEnrichmentPipelineOptions(ConvertPipelineOptions):
+    api_options: PictureDescriptionApiOptions
+
+
+class PostOcrEnrichmentPipeline(SimplePipeline):
+    def __init__(self, pipeline_options: PostOcrEnrichmentPipelineOptions):
+        super().__init__(pipeline_options)
+        self.pipeline_options: PostOcrEnrichmentPipelineOptions
+
+        self.enrichment_pipe = [
+            PostOcrApiEnrichmentModel(
+                enabled=True,
+                enable_remote_services=True,
+                artifacts_path=None,
+                options=self.pipeline_options.api_options,
+                accelerator_options=AcceleratorOptions(),
+            )
+        ]
+
+    @classmethod
+    def get_default_options(cls) -> PostOcrEnrichmentPipelineOptions:
+        return PostOcrEnrichmentPipelineOptions()
+
+    def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
+        def _prepare_elements(
+            conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
+        ) -> Iterable[NodeItem]:
+            # Walk all content layers, descending into pictures as well
+            for doc_element, _level in conv_res.document.iterate_items(
+                traverse_pictures=True,
+                included_content_layers={
+                    ContentLayer.BODY,
+                    ContentLayer.FURNITURE,
+                },
+            ):
+                # prepare_element returns a list of prepared items (or None)
+                prepared_elements = model.prepare_element(
+                    conv_res=conv_res, element=doc_element
+                )
+                if prepared_elements is not None:
+                    yield prepared_elements
+
+        with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
+            for model in self.enrichment_pipe:
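+                # chunkify() (from docling.utils.utils) slices an iterable into
+                # successive batches of at most elements_batch_size items, so
+                # each pass below sends one bounded batch to the remote model.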
+                for element_batch in chunkify(
+                    _prepare_elements(conv_res, model),
+                    model.elements_batch_size,
+                ):
+                    # Must exhaust the generator so every element gets processed!
+                    for element in model(
+                        doc=conv_res.document, element_batch=element_batch
+                    ):
+                        pass
+        return conv_res
+
+
+class PostOcrApiEnrichmentModel(
+    GenericEnrichmentModel[PostOcrEnrichmentElement], BaseModelWithOptions
+):
+    expansion_factor: float = 0.001
+
+    def prepare_element(
+        self, conv_res: ConversionResult, element: NodeItem
+    ) -> Optional[list[PostOcrEnrichmentElement]]:
+        if not self.is_processable(doc=conv_res.document, element=element):
+            return None
+
+        allowed = (DocItem, TableItem, GraphCell)
+        assert isinstance(element, allowed)
+
+        if isinstance(element, KeyValueItem):
+            # Collect crops from the GraphCells inside the key-value item.
+            result = []
+            for c in element.graph.cells:
+                element_prov = c.prov  # Key/value cells have only one provenance!
+                bbox = element_prov.bbox
+                page_ix = element_prov.page_no
+                bbox = bbox.scale_to_size(
+                    old_size=conv_res.document.pages[page_ix].size,
+                    new_size=conv_res.document.pages[page_ix].image.size,
+                )
+                expanded_bbox = bbox.expand_by_scale(
+                    x_scale=self.expansion_factor, y_scale=self.expansion_factor
+                ).to_top_left_origin(
+                    page_height=conv_res.document.pages[page_ix].image.size.height
+                )
+
+                good_bbox = True
+                if (
+                    expanded_bbox.l > expanded_bbox.r
+                    or expanded_bbox.t > expanded_bbox.b
+                ):
+                    good_bbox = False
+
+                if good_bbox:
+                    cropped_image = conv_res.document.pages[
+                        page_ix
+                    ].image.pil_image.crop(expanded_bbox.as_tuple())
+                    # cropped_image.show()
+                    result.append(
+                        PostOcrEnrichmentElement(item=c, image=[cropped_image])
+                    )
+            return result
+        elif isinstance(element, TableItem):
+            element_prov = element.prov[0]
+            page_ix = element_prov.page_no
+            result = []
+            for i, row in enumerate(element.data.grid):
+                for j, cell in enumerate(row):
+                    if hasattr(cell, "bbox") and cell.bbox:
+                        bbox = cell.bbox
+                        bbox = bbox.scale_to_size(
+                            old_size=conv_res.document.pages[page_ix].size,
+                            new_size=conv_res.document.pages[page_ix].image.size,
+                        )
+                        expanded_bbox = bbox.expand_by_scale(
+                            x_scale=self.expansion_factor,
+                            y_scale=self.expansion_factor,
+                        ).to_top_left_origin(
+                            page_height=conv_res.document.pages[
+                                page_ix
+                            ].image.size.height
+                        )
+
+                        good_bbox = True
+                        if (
+                            expanded_bbox.l > expanded_bbox.r
+                            or expanded_bbox.t > expanded_bbox.b
+                        ):
+                            good_bbox = False
+
+                        if good_bbox:
+                            cropped_image = conv_res.document.pages[
+                                page_ix
+                            ].image.pil_image.crop(expanded_bbox.as_tuple())
+                            # cropped_image.show()
+                            result.append(
+                                PostOcrEnrichmentElement(
+                                    item=cell, image=[cropped_image]
+                                )
+                            )
+            return result
+        else:
+            multiple_crops = []
+            # Crop the image from the page, iterating over all provenances
+            for element_prov in element.prov:
+                bbox = element_prov.bbox
+                page_ix = element_prov.page_no
+                bbox = bbox.scale_to_size(
+                    old_size=conv_res.document.pages[page_ix].size,
+                    new_size=conv_res.document.pages[page_ix].image.size,
+                )
+                expanded_bbox = bbox.expand_by_scale(
+                    x_scale=self.expansion_factor, y_scale=self.expansion_factor
+                ).to_top_left_origin(
+                    page_height=conv_res.document.pages[page_ix].image.size.height
+                )
+
+                good_bbox = True
+                if (
+                    expanded_bbox.l > expanded_bbox.r
+                    or expanded_bbox.t > expanded_bbox.b
+                ):
+                    good_bbox = False
+
+                if good_bbox:
+                    cropped_image = conv_res.document.pages[
+                        page_ix
+                    ].image.pil_image.crop(expanded_bbox.as_tuple())
+                    multiple_crops.append(cropped_image)
+                    # cropped_image.show()
+            if len(multiple_crops) > 0:
+                return [PostOcrEnrichmentElement(item=element, image=multiple_crops)]
+            else:
+                return []
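+
+    # Note on the bbox handling above: provenance boxes are in page coordinates
+    # (bottom-left origin), so each box is (1) rescaled from the page size to
+    # the rendered page-image size, (2) padded slightly with expand_by_scale,
+    # and (3) converted to a top-left origin (top = page_height - t; e.g. t=700
+    # on an 800 px tall image becomes top=100) so that PIL's crop() receives
+    # (left, top, right, bottom) in pixel coordinates.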
+ ) + + def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool: + return self.enabled + + def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]: + def _api_request(image: Image.Image) -> str: + return api_image_request( + image=image, + prompt=self.options.prompt, + url=self.options.url, + timeout=self.options.timeout, + headers=self.options.headers, + **self.options.params, + ) + + with ThreadPoolExecutor(max_workers=self.concurrency) as executor: + yield from executor.map(_api_request, images) + + def __call__( + self, + doc: DoclingDocument, + element_batch: Iterable[ItemAndImageEnrichmentElement], + ) -> Iterable[NodeItem]: + if not self.enabled: + for element in element_batch: + yield element.item + return + + elements: list[TextItem] = [] + images: list[Image.Image] = [] + img_ind_per_element: list[int] = [] + + for element_stack in element_batch: + for element in element_stack: + allowed = (DocItem, TableCell, RichTableCell, GraphCell) + assert isinstance(element.item, allowed) + for ind, img in enumerate(element.image): + elements.append(element.item) + images.append(img) + # images.append(element.image) + img_ind_per_element.append(ind) + + if not images: + return + + outputs = list(self._annotate_images(images)) + + for item, output, img_ind in zip(elements, outputs, img_ind_per_element): + # Sometimes model can return html tags, which are not strictly needed in our, so it's better to clean them + def clean_html_tags(text): + for tag in [ + "
| ", + "", + " |