11from collections .abc import Iterable
22from concurrent .futures import ThreadPoolExecutor
3+ from typing import Union
4+
5+ import numpy as np
6+ from PIL .Image import Image
37
48from docling .datamodel .base_models import Page , VlmPrediction
59from docling .datamodel .document import ConversionResult
610from docling .datamodel .pipeline_options_vlm_model import ApiVlmOptions
711from docling .exceptions import OperationNotAllowed
8- from docling .models .base_model import BasePageModel
12+ from docling .models .base_model import BaseVlmPageModel
913from docling .utils .api_image_request import api_image_request
1014from docling .utils .profiling import TimeRecorder
1115
1216
13- class ApiVlmModel (BasePageModel ):
17+ class ApiVlmModel (BaseVlmPageModel ):
18+ # Override the vlm_options type annotation from BaseVlmPageModel
19+ vlm_options : ApiVlmOptions # type: ignore[assignment]
20+
1421 def __init__ (
1522 self ,
1623 enabled : bool ,
@@ -37,36 +44,104 @@ def __init__(
3744 def __call__ (
3845 self , conv_res : ConversionResult , page_batch : Iterable [Page ]
3946 ) -> Iterable [Page ]:
40- def _vlm_request (page ):
47+ page_list = list (page_batch )
48+ if not page_list :
49+ return
50+
51+ valid_pages = []
52+ invalid_pages = []
53+
54+ for page in page_list :
4155 assert page ._backend is not None
4256 if not page ._backend .is_valid ():
43- return page
57+ invalid_pages . append ( page )
4458 else :
45- with TimeRecorder (conv_res , "vlm" ):
46- assert page .size is not None
59+ valid_pages .append (page )
60+
61+ # Process valid pages in batch
62+ if valid_pages :
63+ with TimeRecorder (conv_res , "vlm" ):
64+ # Prepare images and prompts for batch processing
65+ images = []
66+ prompts = []
67+ pages_with_images = []
4768
69+ for page in valid_pages :
70+ assert page .size is not None
4871 hi_res_image = page .get_image (
4972 scale = self .vlm_options .scale , max_size = self .vlm_options .max_size
5073 )
51- assert hi_res_image is not None
52- if hi_res_image :
53- if hi_res_image .mode != "RGB" :
54- hi_res_image = hi_res_image .convert ("RGB" )
55-
56- prompt = self .vlm_options .build_prompt (page .parsed_page )
57- page_tags = api_image_request (
58- image = hi_res_image ,
59- prompt = prompt ,
60- url = self .vlm_options .url ,
61- timeout = self .timeout ,
62- headers = self .vlm_options .headers ,
63- ** self .params ,
64- )
6574
66- page_tags = self .vlm_options .decode_response (page_tags )
67- page .predictions .vlm_response = VlmPrediction (text = page_tags )
75+ # Only process pages with valid images
76+ if hi_res_image is not None :
77+ images .append (hi_res_image )
78+ prompt = self .vlm_options .build_prompt (page )
79+ prompts .append (prompt )
80+ pages_with_images .append (page )
81+
82+ # Use process_images for the actual inference
83+ if images : # Only if we have valid images
84+ predictions = list (self .process_images (images , prompts ))
85+
86+ # Attach results to pages
87+ for page , prediction in zip (pages_with_images , predictions ):
88+ page .predictions .vlm_response = prediction
89+
90+ # Yield all pages (valid and invalid)
91+ for page in invalid_pages :
92+ yield page
93+ for page in valid_pages :
94+ yield page
95+
96+ def process_images (
97+ self ,
98+ image_batch : Iterable [Union [Image , np .ndarray ]],
99+ prompt : Union [str , list [str ]],
100+ ) -> Iterable [VlmPrediction ]:
101+ """Process raw images without page metadata."""
102+ images = list (image_batch )
103+
104+ # Handle prompt parameter
105+ if isinstance (prompt , str ):
106+ prompts = [prompt ] * len (images )
107+ elif isinstance (prompt , list ):
108+ if len (prompt ) != len (images ):
109+ raise ValueError (
110+ f"Prompt list length ({ len (prompt )} ) must match image count ({ len (images )} )"
111+ )
112+ prompts = prompt
113+
114+ def _process_single_image (image_prompt_pair ):
115+ image , prompt_text = image_prompt_pair
116+
117+ # Convert numpy array to PIL Image if needed
118+ if isinstance (image , np .ndarray ):
119+ if image .ndim == 3 and image .shape [2 ] in [3 , 4 ]:
120+ from PIL import Image as PILImage
121+
122+ image = PILImage .fromarray (image .astype (np .uint8 ))
123+ elif image .ndim == 2 :
124+ from PIL import Image as PILImage
125+
126+ image = PILImage .fromarray (image .astype (np .uint8 ), mode = "L" )
127+ else :
128+ raise ValueError (f"Unsupported numpy array shape: { image .shape } " )
129+
130+ # Ensure image is in RGB mode
131+ if image .mode != "RGB" :
132+ image = image .convert ("RGB" )
133+
134+ page_tags = api_image_request (
135+ image = image ,
136+ prompt = prompt_text ,
137+ url = self .vlm_options .url ,
138+ timeout = self .timeout ,
139+ headers = self .vlm_options .headers ,
140+ ** self .params ,
141+ )
68142
69- return page
143+ page_tags = self .vlm_options .decode_response (page_tags )
144+ return VlmPrediction (text = page_tags )
70145
71146 with ThreadPoolExecutor (max_workers = self .concurrency ) as executor :
72- yield from executor .map (_vlm_request , page_batch )
147+ yield from executor .map (_process_single_image , zip ( images , prompts ) )
0 commit comments