Skip to content

Commit 9abf0fd

Browse files
fix: honor picture description batching and scale options (#3132)
* fix: honor picture description batching and scale options
  Signed-off-by: Hassan Raza <raihassanraza10@gmail.com>
* fix: address picture description review feedback
  Signed-off-by: Hassan Raza <raihassanraza10@gmail.com>
* test: fix picture description vlm init on py314
  Signed-off-by: Hassan Raza <raihassanraza10@gmail.com>
* test: simplify picture description vlm init stub
  Signed-off-by: Hassan Raza <raihassanraza10@gmail.com>
---------
Signed-off-by: Hassan Raza <raihassanraza10@gmail.com>
1 parent 4e650af commit 9abf0fd

File tree

5 files changed

+392
-24
lines changed

5 files changed

+392
-24
lines changed

docling/datamodel/pipeline_options.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,19 +536,21 @@ class PictureDescriptionBaseOptions(BaseOptions):
536536
batch_size: Annotated[
537537
int,
538538
Field(
539+
ge=1,
539540
description=(
540541
"Number of images to process in a single batch during picture description. Higher values improve "
541542
"throughput but increase memory usage. Adjust based on available GPU/CPU memory."
542-
)
543+
),
543544
),
544545
] = 8
545546
scale: Annotated[
546547
float,
547548
Field(
549+
gt=0,
548550
description=(
549551
"Scaling factor for image resolution before processing. Higher values (e.g., 2.0) provide more detail "
550552
"for the vision model but increase processing time and memory. Range: 0.5-4.0 typical."
551-
)
553+
),
552554
),
553555
] = 2.0
554556
picture_area_threshold: Annotated[
@@ -715,6 +717,15 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
715717
)
716718
),
717719
] = {"max_new_tokens": 200, "do_sample": False}
720+
padding_side: Annotated[
721+
Literal["left", "right"],
722+
Field(
723+
description=(
724+
"Tokenizer padding side used for batched generation. Defaults to left to preserve the legacy "
725+
"behavior, but can be overridden for models that require right padding."
726+
)
727+
),
728+
] = "left"
718729

719730
@property
720731
def repo_cache_folder(self) -> str:

docling/models/picture_description_base_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,16 @@ def __init__(
3939
options: PictureDescriptionBaseOptions,
4040
accelerator_options: AcceleratorOptions,
4141
):
42+
if options.batch_size < 1:
43+
raise ValueError("Picture description batch_size must be >= 1")
44+
if options.scale <= 0:
45+
raise ValueError("Picture description scale must be > 0")
46+
4247
self.enabled = enabled
4348
self.options = options
4449
self.provenance = "not-implemented"
50+
self.elements_batch_size = options.batch_size
51+
self.images_scale = options.scale
4552

4653
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
4754
return self.enabled and isinstance(element, PictureItem)

docling/models/stages/picture_description/picture_description_vlm_model.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(
5757
import torch
5858
from transformers import (
5959
AutoModelForImageTextToText,
60-
AutoModelForVision2Seq,
6160
AutoProcessor,
6261
)
6362
except ImportError:
@@ -68,6 +67,9 @@ def __init__(
6867
# Initialize processor and model
6968
with _model_init_lock:
7069
self.processor = AutoProcessor.from_pretrained(artifacts_path)
70+
tokenizer = getattr(self.processor, "tokenizer", None)
71+
if tokenizer is not None:
72+
tokenizer.padding_side = self.options.padding_side
7173
self.model = AutoModelForImageTextToText.from_pretrained(
7274
artifacts_path,
7375
device_map=self.device,
@@ -89,6 +91,10 @@ def __init__(
8991
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
9092
from transformers import GenerationConfig
9193

94+
image_batch = list(images)
95+
if not image_batch:
96+
return
97+
9298
# Create input messages
9399
messages = [
94100
{
@@ -100,24 +106,25 @@ def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
100106
},
101107
]
102108

103-
# TODO: do batch generation
104-
105-
for image in images:
106-
# Prepare inputs
107-
prompt = self.processor.apply_chat_template(
108-
messages, add_generation_prompt=True
109-
)
110-
inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
111-
inputs = inputs.to(self.device)
112-
113-
# Generate outputs
114-
generated_ids = self.model.generate(
115-
**inputs,
116-
generation_config=GenerationConfig(**self.options.generation_config),
117-
)
118-
generated_texts = self.processor.batch_decode(
119-
generated_ids[:, inputs["input_ids"].shape[1] :],
120-
skip_special_tokens=True,
121-
)
122-
123-
yield generated_texts[0].strip()
109+
prompt = self.processor.apply_chat_template(
110+
messages, add_generation_prompt=True
111+
)
112+
inputs = self.processor(
113+
text=[prompt] * len(image_batch),
114+
images=image_batch,
115+
return_tensors="pt",
116+
padding=True,
117+
)
118+
inputs = inputs.to(self.device)
119+
120+
generated_ids = self.model.generate(
121+
**inputs,
122+
generation_config=GenerationConfig(**self.options.generation_config),
123+
)
124+
generated_texts = self.processor.batch_decode(
125+
generated_ids[:, inputs["input_ids"].shape[1] :],
126+
skip_special_tokens=True,
127+
)
128+
129+
for text in generated_texts:
130+
yield text.strip()
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from collections.abc import Iterable
2+
from types import SimpleNamespace
3+
from typing import ClassVar, List, Type
4+
5+
import pytest
6+
from docling_core.types.doc import (
7+
DoclingDocument,
8+
ImageRef,
9+
PictureItem,
10+
ProvenanceItem,
11+
)
12+
from docling_core.types.doc.base import BoundingBox, Size
13+
from PIL import Image
14+
15+
from docling.datamodel.accelerator_options import AcceleratorOptions
16+
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
17+
from docling.datamodel.pipeline_options import (
18+
PictureDescriptionBaseOptions,
19+
PictureDescriptionVlmEngineOptions,
20+
PipelineOptions,
21+
)
22+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
23+
from docling.pipeline.base_pipeline import BasePipeline
24+
25+
26+
class _TestOptions(PictureDescriptionBaseOptions):
27+
kind: ClassVar[str] = "test"
28+
29+
30+
class _ConfiguredPictureDescriptionModel(PictureDescriptionBaseModel):
31+
def __init__(self, options: PictureDescriptionBaseOptions) -> None:
32+
super().__init__(
33+
enabled=True,
34+
enable_remote_services=False,
35+
artifacts_path=None,
36+
options=options,
37+
accelerator_options=AcceleratorOptions(),
38+
)
39+
40+
@classmethod
41+
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
42+
return _TestOptions
43+
44+
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
45+
for _image in images:
46+
yield "test description"
47+
48+
49+
class _BatchRecordingPictureDescriptionModel(_ConfiguredPictureDescriptionModel):
50+
def __init__(self, options: PictureDescriptionBaseOptions) -> None:
51+
super().__init__(options)
52+
self.batch_sizes: List[int] = []
53+
54+
def __call__(
55+
self,
56+
doc: DoclingDocument,
57+
element_batch: Iterable[ItemAndImageEnrichmentElement],
58+
) -> Iterable[PictureItem]:
59+
element_list = list(element_batch)
60+
self.batch_sizes.append(len(element_list))
61+
for element in element_list:
62+
assert isinstance(element.item, PictureItem)
63+
yield element.item
64+
65+
66+
class _PictureDescriptionPipeline(BasePipeline):
67+
def _build_document(self, conv_res):
68+
return conv_res
69+
70+
def _determine_status(self, conv_res):
71+
return conv_res.status
72+
73+
@classmethod
74+
def get_default_options(cls) -> PipelineOptions:
75+
return PipelineOptions()
76+
77+
@classmethod
78+
def is_backend_supported(cls, backend) -> bool:
79+
return True
80+
81+
82+
def _make_picture_doc(*, count: int, embed_images: bool = True) -> DoclingDocument:
83+
doc = DoclingDocument(name="test")
84+
for _ in range(count):
85+
image = (
86+
ImageRef.from_pil(Image.new("RGB", (20, 20), "red"), dpi=72)
87+
if embed_images
88+
else None
89+
)
90+
doc.add_picture(image=image)
91+
return doc
92+
93+
94+
def test_picture_description_options_control_batch_size_and_scale() -> None:
95+
model = _ConfiguredPictureDescriptionModel(_TestOptions(batch_size=3, scale=1.5))
96+
97+
assert model.elements_batch_size == 3
98+
assert model.images_scale == 1.5
99+
100+
101+
def test_picture_description_batch_size_controls_pipeline_chunking() -> None:
102+
pipeline = _PictureDescriptionPipeline(PipelineOptions())
103+
model = _BatchRecordingPictureDescriptionModel(_TestOptions(batch_size=2))
104+
pipeline.enrichment_pipe = [model]
105+
conv_res = SimpleNamespace(
106+
document=_make_picture_doc(count=5),
107+
timings={},
108+
status="success",
109+
)
110+
111+
pipeline._enrich_document(conv_res)
112+
113+
assert model.batch_sizes == [2, 2, 1]
114+
115+
116+
def test_picture_description_scale_is_used_for_cropping() -> None:
117+
model = _ConfiguredPictureDescriptionModel(_TestOptions(scale=1.5))
118+
doc = DoclingDocument(name="test")
119+
doc.add_page(page_no=1, size=Size(width=100, height=100))
120+
picture = doc.add_picture(
121+
prov=ProvenanceItem(
122+
page_no=1,
123+
bbox=BoundingBox(l=10, t=10, r=30, b=30),
124+
charspan=(0, 0),
125+
)
126+
)
127+
128+
class _PageSpy:
129+
def __init__(self):
130+
self.page_no = 1
131+
self.calls = []
132+
133+
def get_image(self, *, scale, cropbox):
134+
self.calls.append({"scale": scale, "cropbox": cropbox})
135+
return Image.new("RGB", (5, 5), "blue")
136+
137+
page = _PageSpy()
138+
conv_res = SimpleNamespace(document=doc, pages=[page])
139+
140+
prepared = model.prepare_element(conv_res=conv_res, element=picture)
141+
142+
assert prepared is not None
143+
assert page.calls[0]["scale"] == 1.5
144+
145+
146+
def test_picture_description_embedded_images_keep_original_size() -> None:
147+
model = _ConfiguredPictureDescriptionModel(_TestOptions(scale=1.5))
148+
doc = _make_picture_doc(count=1, embed_images=True)
149+
150+
prepared = model.prepare_element(
151+
conv_res=SimpleNamespace(document=doc, pages=[]), element=doc.pictures[0]
152+
)
153+
154+
assert prepared is not None
155+
assert prepared.image.size == (20, 20)
156+
157+
158+
def test_picture_description_batch_size_must_be_positive() -> None:
159+
with pytest.raises(ValueError):
160+
_TestOptions(batch_size=0)
161+
162+
163+
def test_picture_description_scale_must_be_positive() -> None:
164+
with pytest.raises(ValueError):
165+
_TestOptions(scale=0)
166+
167+
168+
def test_picture_description_preset_batch_size_must_be_positive() -> None:
169+
options = PictureDescriptionVlmEngineOptions.from_preset("smolvlm", batch_size=0)
170+
171+
with pytest.raises(ValueError, match="batch_size"):
172+
_ConfiguredPictureDescriptionModel(options)
173+
174+
175+
def test_picture_description_preset_scale_must_be_positive() -> None:
176+
options = PictureDescriptionVlmEngineOptions.from_preset("smolvlm", scale=0)
177+
178+
with pytest.raises(ValueError, match="scale"):
179+
_ConfiguredPictureDescriptionModel(options)

0 commit comments

Comments (0)