diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index 9b03d58a9f..3e80243c51 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel, ConfigDict
@@ -9,6 +9,11 @@
 from docling.datamodel.accelerator_options import AcceleratorDevice
 from docling.models.utils.generation_utils import GenerationStopper
 
+if TYPE_CHECKING:
+    from docling_core.types.doc.page import SegmentedPage
+
+    from docling.datamodel.base_models import Page
+
 
 class BaseVlmOptions(BaseModel):
     kind: str
@@ -17,7 +22,7 @@
     max_size: Optional[int] = None
     temperature: float = 0.0
 
-    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+    def build_prompt(self, page: Optional[Union["Page", "SegmentedPage"]]) -> str:
         return self.prompt
 
     def decode_response(self, text: str) -> str:
diff --git a/docling/experimental/__init__.py b/docling/experimental/__init__.py
new file mode 100644
index 0000000000..e21e5131b8
--- /dev/null
+++ b/docling/experimental/__init__.py
@@ -0,0 +1,5 @@
+"""Experimental modules for Docling.
+
+This package contains experimental features that are under development
+and may change or be removed in future versions.
+"""
diff --git a/docling/experimental/datamodel/__init__.py b/docling/experimental/datamodel/__init__.py
new file mode 100644
index 0000000000..c76b060ae6
--- /dev/null
+++ b/docling/experimental/datamodel/__init__.py
@@ -0,0 +1 @@
+"""Experimental datamodel modules."""
diff --git a/docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py b/docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py
new file mode 100644
index 0000000000..94d6ac3b5b
--- /dev/null
+++ b/docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py
@@ -0,0 +1,31 @@
+"""Options for the threaded layout+VLM pipeline."""
+
+from typing import Union
+
+from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_HERON
+from docling.datamodel.pipeline_options import LayoutOptions, PaginatedPipelineOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    ApiVlmOptions,
+    InlineVlmOptions,
+)
+from docling.datamodel.vlm_model_specs import GRANITEDOCLING_TRANSFORMERS
+
+
+class ThreadedLayoutVlmPipelineOptions(PaginatedPipelineOptions):
+    """Pipeline options for the threaded layout+VLM pipeline."""
+
+    images_scale: float = 2.0
+
+    # VLM configuration (will be enhanced with layout awareness by the pipeline)
+    vlm_options: Union[InlineVlmOptions, ApiVlmOptions] = GRANITEDOCLING_TRANSFORMERS
+
+    # Layout model configuration
+    layout_options: LayoutOptions = LayoutOptions(
+        model_spec=DOCLING_LAYOUT_HERON, skip_cell_assignment=True
+    )
+
+    # Threading and batching controls
+    layout_batch_size: int = 4
+    vlm_batch_size: int = 4
+    batch_timeout_seconds: float = 2.0
+    queue_max_size: int = 50
diff --git a/docling/experimental/demo_layout_vlm.py b/docling/experimental/demo_layout_vlm.py
new file mode 100644
index 0000000000..dbb9702656
--- /dev/null
+++ b/docling/experimental/demo_layout_vlm.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+"""Demo script for the new ThreadedLayoutVlmPipeline.
+
+This script demonstrates the usage of the new pipeline that combines
+layout model preprocessing with VLM processing in a threaded manner.
+"""
+
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+
+from docling.datamodel.base_models import ConversionStatus, DocumentStream, InputFormat
+from docling.datamodel.pipeline_options import VlmPipelineOptions
+from docling.datamodel.vlm_model_specs import (
+    GRANITEDOCLING_TRANSFORMERS,
+    GRANITEDOCLING_VLLM,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+from docling.experimental.datamodel.threaded_layout_vlm_pipeline_options import (
+    ThreadedLayoutVlmPipelineOptions,
+)
+from docling.experimental.pipeline.threaded_layout_vlm_pipeline import (
+    ThreadedLayoutVlmPipeline,
+)
+from docling.pipeline.vlm_pipeline import VlmPipeline
+
+_log = logging.getLogger(__name__)
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        description="Demo script for the new ThreadedLayoutVlmPipeline"
+    )
+    parser.add_argument(
+        "--input", type=str, required=True, help="Input directory containing PDF files"
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="../results/",
+        help="Output directory for converted files",
+    )
+    return parser.parse_args()
+
+
+def _get_docs(input_doc_paths):
+    """Yield DocumentStream objects from a list of input document paths."""
+    for path in input_doc_paths:
+        buf = BytesIO(path.read_bytes())
+        stream = DocumentStream(name=path.name, stream=buf)
+        yield stream
+
+
+def demo_threaded_layout_vlm_pipeline(
+    input_doc_paths: list[Path], out_dir_layout_aware: Path, out_dir_classic_vlm: Path
+):
+    """Demonstrate the threaded layout+VLM pipeline."""
+
+    # Configure pipeline options
+    print("Configuring pipeline options...")
+    pipeline_options_layout_aware = ThreadedLayoutVlmPipelineOptions(
+        # VLM configuration - defaults to GRANITEDOCLING_TRANSFORMERS
+        vlm_options=GRANITEDOCLING_TRANSFORMERS,
+        # Layout configuration - defaults to DOCLING_LAYOUT_HERON
+        # Batch sizes for parallel processing
+        layout_batch_size=2,
+        vlm_batch_size=1,
+        # Queue configuration
+        queue_max_size=10,
+        batch_timeout_seconds=1.0,
+        # Image processing
+        images_scale=2.0,
+        generate_page_images=True,
+    )
+
+    pipeline_options_classic_vlm = VlmPipelineOptions(vlm_options=GRANITEDOCLING_VLLM)
+
+    # Create converter with the new pipeline
+    print("Initializing DocumentConverter (this may take a while - loading models)...")
+    doc_converter_layout_enhanced = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_cls=ThreadedLayoutVlmPipeline,
+                pipeline_options=pipeline_options_layout_aware,
+            )
+        }
+    )
+    doc_converter_classic_vlm = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_cls=VlmPipeline,
+                pipeline_options=pipeline_options_classic_vlm,
+            ),
+        }
+    )
+
+    print(f"Starting conversion of {len(input_doc_paths)} document(s)...")
+    result_layout_aware = doc_converter_layout_enhanced.convert_all(
+        list(_get_docs(input_doc_paths)), raises_on_error=False
+    )
+    result_without_layout = doc_converter_classic_vlm.convert_all(
+        list(_get_docs(input_doc_paths)), raises_on_error=False
+    )
+
+    for conv_result in result_layout_aware:
+        if conv_result.status == ConversionStatus.FAILURE:
+            _log.error(f"Conversion failed: {conv_result.status}")
+            continue
+
+        doc_filename = conv_result.input.file.stem
+        conv_result.document.save_as_doctags(
+            out_dir_layout_aware / f"{doc_filename}.dt"
+        )
+
+    for conv_result in result_without_layout:
+        if conv_result.status == ConversionStatus.FAILURE:
+            _log.error(f"Conversion failed: {conv_result.status}")
+            continue
+
+        doc_filename = conv_result.input.file.stem
+        conv_result.document.save_as_doctags(out_dir_classic_vlm / f"{doc_filename}.dt")
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    try:
+        print("Starting script...")
+        args = _parse_args()
+        print(f"Parsed arguments: input={args.input}, output={args.output}")
+
+        base_path = Path(args.input)
+
+        print(f"Searching for PDFs in: {base_path}")
+        input_doc_paths = sorted(base_path.rglob("*.*"))
+        input_doc_paths = [
+            e
+            for e in input_doc_paths
+            if e.suffix.lower() == ".pdf"
+        ]
+
+        if not input_doc_paths:
+            raise SystemExit(f"No PDF files found in {base_path}")
+
+        print(f"Found {len(input_doc_paths)} PDF file(s)")
+
+        out_dir_layout_aware = (
+            Path(args.output) / "layout_aware" / "model_output" / "layout" / "doc_tags"
+        )
+        out_dir_classic_vlm = (
+            Path(args.output) / "classic_vlm" / "model_output" / "layout" / "doc_tags"
+        )
+        out_dir_layout_aware.mkdir(parents=True, exist_ok=True)
+        out_dir_classic_vlm.mkdir(parents=True, exist_ok=True)
+
+        _log.info("Calling demo_threaded_layout_vlm_pipeline...")
+        demo_threaded_layout_vlm_pipeline(
+            input_doc_paths, out_dir_layout_aware, out_dir_classic_vlm
+        )
+        _log.info("Script completed successfully!")
+    except Exception as e:
+        print(f"ERROR: {type(e).__name__}: {e}")
+        import traceback
+
+        traceback.print_exc()
+        raise
diff --git a/docling/experimental/pipeline/__init__.py b/docling/experimental/pipeline/__init__.py
new file mode 100644
index 0000000000..9a9667633b
--- /dev/null
+++ b/docling/experimental/pipeline/__init__.py
@@ -0,0 +1 @@
+"""Experimental pipeline modules."""
diff --git a/docling/experimental/pipeline/threaded_layout_vlm_pipeline.py b/docling/experimental/pipeline/threaded_layout_vlm_pipeline.py
new file mode 100644
index 0000000000..d68996f751
--- /dev/null
+++ b/docling/experimental/pipeline/threaded_layout_vlm_pipeline.py
@@ -0,0 +1,397 @@
+"""Threaded Layout+VLM Pipeline
+================================
+A specialized two-stage threaded pipeline that combines layout model preprocessing
+with VLM processing. The layout model detects document elements and coordinates,
+which are then injected into the VLM prompt for enhanced structured output.
+""" + +from __future__ import annotations + +import itertools +import logging +from pathlib import Path +from typing import Iterable, List, Optional, Union, cast + +from docling_core.types.doc import DoclingDocument +from docling_core.types.doc.document import DocTagsDocument +from PIL import Image as PILImage + +from docling.backend.abstract_backend import AbstractDocumentBackend +from docling.backend.pdf_backend import PdfDocumentBackend +from docling.datamodel.base_models import ConversionStatus, Page +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options_vlm_model import ( + ApiVlmOptions, + InferenceFramework, + InlineVlmOptions, +) +from docling.datamodel.settings import settings +from docling.experimental.datamodel.threaded_layout_vlm_pipeline_options import ( + ThreadedLayoutVlmPipelineOptions, +) +from docling.models.api_vlm_model import ApiVlmModel +from docling.models.base_model import BaseVlmPageModel +from docling.models.layout_model import LayoutModel +from docling.models.vlm_models_inline.hf_transformers_model import ( + HuggingFaceTransformersVlmModel, +) +from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel +from docling.pipeline.base_pipeline import BasePipeline +from docling.pipeline.standard_pdf_pipeline import ( + ProcessingResult, + RunContext, + ThreadedItem, + ThreadedPipelineStage, + ThreadedQueue, +) +from docling.utils.profiling import ProfilingScope, TimeRecorder + +_log = logging.getLogger(__name__) + + +class ThreadedLayoutVlmPipeline(BasePipeline): + """Two-stage threaded pipeline: Layout Model → VLM Model.""" + + def __init__(self, pipeline_options: ThreadedLayoutVlmPipelineOptions) -> None: + super().__init__(pipeline_options) + self.pipeline_options: ThreadedLayoutVlmPipelineOptions = pipeline_options + self._run_seq = itertools.count(1) # deterministic, monotonic run ids + + # VLM model type (initialized in _init_models) + self.vlm_model: BaseVlmPageModel + + # Initialize models + self._init_models() + + def _init_models(self) -> None: + """Initialize layout and VLM models.""" + art_path = self._resolve_artifacts_path() + + # Layout model + self.layout_model = LayoutModel( + artifacts_path=art_path, + accelerator_options=self.pipeline_options.accelerator_options, + options=self.pipeline_options.layout_options, + ) + + # VLM model based on options type + # Create layout-aware VLM options internally + base_vlm_options = self.pipeline_options.vlm_options + + class LayoutAwareVlmOptions(type(base_vlm_options)): # type: ignore[misc] + def build_prompt(self, page): + from docling.datamodel.base_models import Page + + base_prompt = self.prompt + + # If we have a full Page object with layout predictions, enhance the prompt + if isinstance(page, Page) and page.predictions.layout: + from docling_core.types.doc.tokens import DocumentToken + + layout_elements = [] + for cluster in page.predictions.layout.clusters: + # Get proper tag name from DocItemLabel + tag_name = DocumentToken.create_token_name_from_doc_item_label( + label=cluster.label + ) + + # Convert bbox to tuple and get location tokens + bbox_tuple = cluster.bbox.as_tuple() + location_tokens = DocumentToken.get_location( + bbox=bbox_tuple, + page_w=page.size.width, + page_h=page.size.height, + xsize=500, + ysize=500, + ) + + # Create XML element with DocTags format + xml_element = f"<{tag_name}>{location_tokens}" + layout_elements.append(xml_element) + + if layout_elements: + # Join elements with newlines and wrap in layout tags + layout_xml 
+                            "<layout>" + "\n".join(layout_elements) + "</layout>"
+                        )
+                        layout_injection = layout_xml
+
+                        custom_prompt = base_prompt + layout_injection
+                        _log.debug("Layout injection prompt: %s", custom_prompt)
+
+                        return custom_prompt
+
+                return base_prompt
+
+        vlm_options = LayoutAwareVlmOptions(**base_vlm_options.model_dump())
+
+        if isinstance(base_vlm_options, ApiVlmOptions):
+            self.vlm_model = ApiVlmModel(
+                enabled=True,
+                enable_remote_services=self.pipeline_options.enable_remote_services,
+                vlm_options=vlm_options,
+            )
+        elif isinstance(base_vlm_options, InlineVlmOptions):
+            if vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
+                self.vlm_model = HuggingFaceTransformersVlmModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            elif vlm_options.inference_framework == InferenceFramework.MLX:
+                self.vlm_model = HuggingFaceMlxModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            elif vlm_options.inference_framework == InferenceFramework.VLLM:
+                from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
+
+                self.vlm_model = VllmVlmModel(
+                    enabled=True,
+                    artifacts_path=art_path,
+                    accelerator_options=self.pipeline_options.accelerator_options,
+                    vlm_options=vlm_options,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported VLM inference framework: {vlm_options.inference_framework}"
+                )
+        else:
+            raise ValueError(f"Unsupported VLM options type: {type(base_vlm_options)}")
+
+    def _resolve_artifacts_path(self) -> Optional[Path]:
+        """Resolve artifacts path from options or settings."""
+        if self.pipeline_options.artifacts_path:
+            p = Path(self.pipeline_options.artifacts_path).expanduser()
+        elif settings.artifacts_path:
+            p = Path(settings.artifacts_path).expanduser()
+        else:
+            return None
+        if not p.is_dir():
+            raise RuntimeError(
+                f"{p} does not exist or is not a directory containing the required models"
+            )
+        return p
+
+    def _create_run_ctx(self) -> RunContext:
+        """Create pipeline stages and wire them together."""
+        opts = self.pipeline_options
+
+        # Layout stage
+        layout_stage = ThreadedPipelineStage(
+            name="layout",
+            model=self.layout_model,
+            batch_size=opts.layout_batch_size,
+            batch_timeout=opts.batch_timeout_seconds,
+            queue_max_size=opts.queue_max_size,
+        )
+
+        # VLM stage - now layout-aware through enhanced build_prompt
+        vlm_stage = ThreadedPipelineStage(
+            name="vlm",
+            model=self.vlm_model,
+            batch_size=opts.vlm_batch_size,
+            batch_timeout=opts.batch_timeout_seconds,
+            queue_max_size=opts.queue_max_size,
+        )
+
+        # Wire stages
+        output_q = ThreadedQueue(opts.queue_max_size)
+        layout_stage.add_output_queue(vlm_stage.input_queue)
+        vlm_stage.add_output_queue(output_q)
+
+        stages = [layout_stage, vlm_stage]
+        return RunContext(
+            stages=stages, first_stage=layout_stage, output_queue=output_q
+        )
+
+    def _build_document(self, conv_res: ConversionResult) -> ConversionResult:
+        """Build document using threaded layout+VLM pipeline."""
+        run_id = next(self._run_seq)
+        assert isinstance(conv_res.input._backend, PdfDocumentBackend)
+        backend = conv_res.input._backend
+
+        # Initialize pages
+        start_page, end_page = conv_res.input.limits.page_range
+        pages: List[Page] = []
+        for i in range(conv_res.input.page_count):
+            if start_page - 1 <= i <= end_page - 1:
+                page = Page(page_no=i)
+                page._backend = backend.load_page(i)
+                if page._backend and page._backend.is_valid():
+                    page.size = page._backend.get_size()
+                    conv_res.pages.append(page)
+                    pages.append(page)
+
+        if not pages:
+            conv_res.status = ConversionStatus.FAILURE
+            return conv_res
+
+        total_pages = len(pages)
+        ctx = self._create_run_ctx()
+        for st in ctx.stages:
+            st.start()
+
+        proc = ProcessingResult(total_expected=total_pages)
+        fed_idx = 0
+        batch_size = 32
+
+        try:
+            while proc.success_count + proc.failure_count < total_pages:
+                # Feed pages to first stage
+                while fed_idx < total_pages:
+                    ok = ctx.first_stage.input_queue.put(
+                        ThreadedItem(
+                            payload=pages[fed_idx],
+                            run_id=run_id,
+                            page_no=pages[fed_idx].page_no,
+                            conv_res=conv_res,
+                        ),
+                        timeout=0.0,
+                    )
+                    if ok:
+                        fed_idx += 1
+                        if fed_idx == total_pages:
+                            ctx.first_stage.input_queue.close()
+                    else:
+                        break
+
+                # Drain results from output
+                out_batch = ctx.output_queue.get_batch(batch_size, timeout=0.05)
+                for itm in out_batch:
+                    if itm.run_id != run_id:
+                        continue
+                    if itm.is_failed or itm.error:
+                        proc.failed_pages.append(
+                            (itm.page_no, itm.error or RuntimeError("unknown error"))
+                        )
+                    else:
+                        assert itm.payload is not None
+                        proc.pages.append(itm.payload)
+
+                # Handle early termination
+                if not out_batch and ctx.output_queue.closed:
+                    missing = total_pages - (proc.success_count + proc.failure_count)
+                    if missing > 0:
+                        proc.failed_pages.extend(
+                            [(-1, RuntimeError("pipeline terminated early"))] * missing
+                        )
+                    break
+        finally:
+            for st in ctx.stages:
+                st.stop()
+            ctx.output_queue.close()
+
+        self._integrate_results(conv_res, proc)
+        return conv_res
+
+    def _integrate_results(
+        self, conv_res: ConversionResult, proc: ProcessingResult
+    ) -> None:
+        """Integrate processing results into conversion result."""
+        page_map = {p.page_no: p for p in proc.pages}
+        conv_res.pages = [
+            page_map.get(p.page_no, p)
+            for p in conv_res.pages
+            if p.page_no in page_map
+            or not any(fp == p.page_no for fp, _ in proc.failed_pages)
+        ]
+
+        if proc.is_complete_failure:
+            conv_res.status = ConversionStatus.FAILURE
+        elif proc.is_partial_success:
+            conv_res.status = ConversionStatus.PARTIAL_SUCCESS
+        else:
+            conv_res.status = ConversionStatus.SUCCESS
+
+        # Clean up images if not needed
+        if not self.pipeline_options.generate_page_images:
+            for p in conv_res.pages:
+                p._image_cache = {}
+
+    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
+        """Assemble final document from VLM predictions."""
+        from docling_core.types.doc import DocItem, ImageRef, PictureItem
+
+        from docling.datamodel.pipeline_options_vlm_model import ResponseFormat
+
+        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
+            # Assemble document using DOCTAGS format only
+            if (
+                self.pipeline_options.vlm_options.response_format
+                == ResponseFormat.DOCTAGS
+            ):
+                conv_res.document = self._turn_dt_into_doc(conv_res)
+            else:
+                raise RuntimeError(
+                    f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}. Only DOCTAGS format is supported."
+                )
+
+        # Generate images of the requested element types
+        if self.pipeline_options.generate_picture_images:
+            scale = self.pipeline_options.images_scale
+            for element, _level in conv_res.document.iterate_items():
+                if not isinstance(element, DocItem) or len(element.prov) == 0:
+                    continue
+                if (
+                    isinstance(element, PictureItem)
+                    and self.pipeline_options.generate_picture_images
+                ):
+                    page_ix = element.prov[0].page_no - 1
+                    page = conv_res.pages[page_ix]
+                    assert page.size is not None
+                    assert page.image is not None
+
+                    crop_bbox = (
+                        element.prov[0]
+                        .bbox.scaled(scale=scale)
+                        .to_top_left_origin(page_height=page.size.height * scale)
+                    )
+
+                    cropped_im = page.image.crop(crop_bbox.as_tuple())
+                    element.image = ImageRef.from_pil(
+                        cropped_im, dpi=int(72 * scale)
+                    )
+
+        return conv_res
+
+    def _turn_dt_into_doc(self, conv_res: ConversionResult) -> DoclingDocument:
+        """Convert DOCTAGS response format to DoclingDocument."""
+        doctags_list = []
+        image_list = []
+        for page in conv_res.pages:
+            # Only include pages that have both an image and VLM predictions
+            if page.image and page.predictions.vlm_response:
+                predicted_doctags = page.predictions.vlm_response.text
+                image_list.append(page.image)
+                doctags_list.append(predicted_doctags)
+
+        doctags_list_c = cast(List[Union[Path, str]], doctags_list)
+        image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
+        doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
+            doctags_list_c, image_list_c
+        )
+        document = DoclingDocument.load_from_doctags(doctag_document=doctags_doc)
+
+        return document
+
+    @classmethod
+    def get_default_options(cls) -> ThreadedLayoutVlmPipelineOptions:
+        return ThreadedLayoutVlmPipelineOptions()
+
+    @classmethod
+    def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
+        return isinstance(backend, PdfDocumentBackend)
+
+    def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
+        return conv_res.status
+
+    def _unload(self, conv_res: ConversionResult) -> None:
+        for p in conv_res.pages:
+            if p._backend is not None:
+                p._backend.unload()
+        if conv_res.input._backend:
+            conv_res.input._backend.unload()
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index 2c9a1f9a78..b34c275a36 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -1,13 +1,15 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
+from typing import Union
 
-from transformers import StoppingCriteria
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.generation_utils import GenerationStopper
 from docling.utils.api_image_request import (
     api_image_request,
@@ -16,7 +18,10 @@
 from docling.utils.profiling import TimeRecorder
 
 
-class ApiVlmModel(BasePageModel):
+class ApiVlmModel(BaseVlmPageModel):
+    # Override the vlm_options type annotation from BaseVlmPageModel
+    vlm_options: ApiVlmOptions  # type: ignore[assignment]
+
     def __init__(
         self,
         enabled: bool,
@@ -43,66 +48,130 @@
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        def _vlm_request(page):
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        original_order = page_list[:]
+        valid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                return page
+            if page._backend.is_valid():
+                valid_pages.append(page)
 
+        # Process valid pages in batch
+        if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-                assert page.size is not None
+                # Prepare images and prompts for batch processing
+                images = []
+                prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+                        # NOTE: unlike the inline VLM models, this still passes
+                        # page.parsed_page rather than the full Page to build_prompt
+                        prompt = self.vlm_options.build_prompt(page.parsed_page)
+                        prompts.append(prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield pages preserving original order
+        for page in original_order:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+        images = list(image_batch)
 
-            hi_res_image = page.get_image(
-                scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            prompts = [prompt] * len(images)
+        elif isinstance(prompt, list):
+            if len(prompt) != len(images):
+                raise ValueError(
+                    f"Prompt list length ({len(prompt)}) must match image count ({len(images)})"
                 )
-            assert hi_res_image is not None
-            if hi_res_image and hi_res_image.mode != "RGB":
-                hi_res_image = hi_res_image.convert("RGB")
-
-            prompt = self.vlm_options.build_prompt(page.parsed_page)
-            stop_reason = VlmStopReason.UNSPECIFIED
-
-            if self.vlm_options.custom_stopping_criteria:
-                # Instantiate any GenerationStopper classes before passing to streaming
-                instantiated_stoppers = []
-                for criteria in self.vlm_options.custom_stopping_criteria:
-                    if isinstance(criteria, GenerationStopper):
-                        instantiated_stoppers.append(criteria)
-                    elif isinstance(criteria, type) and issubclass(
-                        criteria, GenerationStopper
-                    ):
-                        instantiated_stoppers.append(criteria())
-                # Skip non-GenerationStopper criteria (should have been caught in validation)
-
-                # Streaming path with early abort support
-                with TimeRecorder(conv_res, "vlm_inference"):
-                    page_tags, num_tokens = api_image_request_streaming(
-                        image=hi_res_image,
-                        prompt=prompt,
-                        url=self.vlm_options.url,
-                        timeout=self.timeout,
-                        headers=self.vlm_options.headers,
-                        generation_stoppers=instantiated_stoppers,
-                        **self.params,
-                    )
-                    page_tags = self.vlm_options.decode_response(page_tags)
+            prompts = prompt
+
+        def _process_single_image(image_prompt_pair):
+            image, prompt_text = image_prompt_pair
+
+            # Convert numpy array to PIL Image if needed
+            if isinstance(image, np.ndarray):
+                if image.ndim == 3 and image.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8))
+                elif image.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8), mode="L")
             else:
-                # Non-streaming fallback (existing behavior)
-                with TimeRecorder(conv_res, "vlm_inference"):
"vlm_inference"): - page_tags, num_tokens, stop_reason = api_image_request( - image=hi_res_image, - prompt=prompt, - url=self.vlm_options.url, - timeout=self.timeout, - headers=self.vlm_options.headers, - **self.params, - ) - - page_tags = self.vlm_options.decode_response(page_tags) - - page.predictions.vlm_response = VlmPrediction( - text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason + raise ValueError(f"Unsupported numpy array shape: {image.shape}") + + # Ensure image is in RGB mode + if image.mode != "RGB": + image = image.convert("RGB") + + stop_reason = VlmStopReason.UNSPECIFIED + + if self.vlm_options.custom_stopping_criteria: # Ask christoph + # Instantiate any GenerationStopper classes before passing to streaming + instantiated_stoppers = [] + for criteria in self.vlm_options.custom_stopping_criteria: + if isinstance(criteria, GenerationStopper): + instantiated_stoppers.append(criteria) + elif isinstance(criteria, type) and issubclass( + criteria, GenerationStopper + ): + instantiated_stoppers.append(criteria()) + # Skip non-GenerationStopper criteria (should have been caught in validation) + + # Streaming path with early abort support + page_tags, num_tokens = api_image_request_streaming( + image=image, + prompt=prompt_text, + url=self.vlm_options.url, + timeout=self.timeout, + headers=self.vlm_options.headers, + generation_stoppers=instantiated_stoppers, + **self.params, ) - return page + else: + # Non-streaming fallback (existing behavior) + page_tags, num_tokens, stop_reason = api_image_request( + image=image, + prompt=prompt_text, + url=self.vlm_options.url, + timeout=self.timeout, + headers=self.vlm_options.headers, + **self.params, + ) + + page_tags = self.vlm_options.decode_response(page_tags) + return VlmPrediction( + text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason + ) with ThreadPoolExecutor(max_workers=self.concurrency) as executor: - yield from executor.map(_vlm_request, page_batch) + yield from executor.map(_process_single_image, zip(images, prompts)) diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py index f9aefcb894..7e56daa19b 100644 --- a/docling/models/vlm_models_inline/hf_transformers_model.py +++ b/docling/models/vlm_models_inline/hf_transformers_model.py @@ -176,7 +176,7 @@ def __call__( images.append(hi_res_image) # Define prompt structure - user_prompt = self.vlm_options.build_prompt(page.parsed_page) + user_prompt = self.vlm_options.build_prompt(page) user_prompts.append(user_prompt) pages_with_images.append(page) diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py index 871c19ba23..6e644c43ec 100644 --- a/docling/models/vlm_models_inline/mlx_model.py +++ b/docling/models/vlm_models_inline/mlx_model.py @@ -134,10 +134,7 @@ def __call__( images.append(hi_res_image) # Define prompt structure - if callable(self.vlm_options.prompt): - user_prompt = self.vlm_options.prompt(page.parsed_page) - else: - user_prompt = self.vlm_options.prompt + user_prompt = self.vlm_options.build_prompt(page) user_prompts.append(user_prompt) pages_with_images.append(page)