
Commit 6a04e27

Authored by ElHachem02 and Peter El Hachem
feat(vlm): track generated tokens and stop reasons for VLM models (#2543)
* feat: add enum StopReason and use it in VlmPrediction
* add vlm_inference time for api calls and track stop reason
* fix: rename enum to VlmStopReason
* Propagate partial success status if page reaches max tokens
* feat: page with generation stopped by loop detector creates partial success status
* Add hint for future improvement
* fix: remove vlm_stop_reason from extracted page data, add UNSPECIFIED state as VlmStopReason to avoid null value

Signed-off-by: ElHachem02 <[email protected]>
Signed-off-by: Peter El Hachem <[email protected]>
Co-authored-by: Peter El Hachem <[email protected]>
1 parent 1a5146a commit 6a04e27

File tree

10 files changed: +92 -52 lines

docling/datamodel/base_models.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -166,6 +166,13 @@ class DoclingComponentType(str, Enum):
     USER_INPUT = "user_input"
 
 
+class VlmStopReason(str, Enum):
+    LENGTH = "length"  # max tokens reached
+    STOP_SEQUENCE = "stop_sequence"  # Custom stopping criteria met
+    END_OF_SEQUENCE = "end_of_sequence"  # Model generated end-of-text token
+    UNSPECIFIED = "unspecified"  # Default value when no stop reason is reported
+
+
 class ErrorItem(BaseModel):
     component_type: DoclingComponentType
     module_name: str
@@ -208,7 +215,7 @@ class VlmPrediction(BaseModel):
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
     num_tokens: Optional[int] = None
-    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
+    stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED
```

docling/datamodel/extraction.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@
 
 from pydantic import BaseModel, Field
 
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 
 
```

docling/models/api_vlm_model.py

Lines changed: 25 additions & 20 deletions

```diff
@@ -3,7 +3,7 @@
 
 from transformers import StoppingCriteria
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
@@ -59,6 +59,7 @@ def _vlm_request(page):
                 hi_res_image = hi_res_image.convert("RGB")
 
             prompt = self.vlm_options.build_prompt(page.parsed_page)
+            stop_reason = VlmStopReason.UNSPECIFIED
 
             if self.vlm_options.custom_stopping_criteria:
                 # Instantiate any GenerationStopper classes before passing to streaming
@@ -73,29 +74,33 @@ def _vlm_request(page):
                 # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                 # Streaming path with early abort support
-                page_tags, num_tokens = api_image_request_streaming(
-                    image=hi_res_image,
-                    prompt=prompt,
-                    url=self.vlm_options.url,
-                    timeout=self.timeout,
-                    headers=self.vlm_options.headers,
-                    generation_stoppers=instantiated_stoppers,
-                    **self.params,
-                )
+                with TimeRecorder(conv_res, "vlm_inference"):
+                    page_tags, num_tokens = api_image_request_streaming(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        generation_stoppers=instantiated_stoppers,
+                        **self.params,
+                    )
+                page_tags = self.vlm_options.decode_response(page_tags)
             else:
                 # Non-streaming fallback (existing behavior)
-                page_tags, num_tokens = api_image_request(
-                    image=hi_res_image,
-                    prompt=prompt,
-                    url=self.vlm_options.url,
-                    timeout=self.timeout,
-                    headers=self.vlm_options.headers,
-                    **self.params,
-                )
+                with TimeRecorder(conv_res, "vlm_inference"):
+                    page_tags, num_tokens, stop_reason = api_image_request(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        **self.params,
+                    )
+
+                page_tags = self.vlm_options.decode_response(page_tags)
 
-            page_tags = self.vlm_options.decode_response(page_tags)
             page.predictions.vlm_response = VlmPrediction(
-                text=page_tags, num_tokens=num_tokens
+                text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
             )
             return page
 
```
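
Wrapping both request paths in `TimeRecorder(conv_res, "vlm_inference")` makes the remote inference time visible in the pipeline profiling output. A hedged sketch of reading it back, assuming profiling is enabled and recorded scopes are keyed by name in the result's timings:

```python
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter

# Profiling must be enabled for TimeRecorder scopes to be collected.
settings.debug.profile_pipeline_timings = True

# Assumption: the converter is configured with a VLM pipeline that uses
# ApiVlmOptions; a default converter will not record "vlm_inference".
converter = DocumentConverter()
result = converter.convert("report.pdf")  # hypothetical input document

timing = result.timings.get("vlm_inference")
if timing is not None:
    print(f"VLM inference times (s): {timing.times}")
```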

docling/models/picture_description_api_model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -51,7 +51,7 @@ def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
-            response, _ = api_image_request(
+            page_tags, _, _ = api_image_request(
                 image=image,
                 prompt=self.options.prompt,
                 url=self.options.url,
@@ -60,7 +60,7 @@ def _api_request(image):
                 **self.options.params,
             )
 
-            return response
+            return page_tags
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_api_request, images)
```

docling/models/vlm_models_inline/hf_transformers_model.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -13,7 +13,7 @@
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -382,4 +382,5 @@ def process_images(
                 text=decoded_text,
                 generation_time=generation_time,
                 num_tokens=num_tokens,
+                stop_reason=VlmStopReason.UNSPECIFIED,
             )
```

docling/models/vlm_models_inline/mlx_model.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -13,7 +13,12 @@
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
@@ -319,5 +324,6 @@ def process_images(
                 generation_time=generation_time,
                 generated_tokens=tokens,
                 num_tokens=len(tokens),
+                stop_reason=VlmStopReason.UNSPECIFIED,
             )
             _log.debug("MLX model: Released global lock")
```

docling/models/vlm_models_inline/nuextract_transformers_model.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -12,7 +12,7 @@
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import VlmPrediction
+from docling.datamodel.base_models import VlmPrediction, VlmStopReason
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmModel
 from docling.models.utils.hf_model_download import (
@@ -284,6 +284,7 @@ def process_images(
         # Optional logging
         num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            # Todo: confirm num tokens is actually from first item, code was already like this
            num_tokens = int(generated_ids[0].shape[0])
            _log.debug(
                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
@@ -297,4 +298,5 @@ def process_images(
             text=decoded_text,
             generation_time=generation_time,
             num_tokens=num_tokens,
+            stop_reason=VlmStopReason.UNSPECIFIED,
         )
```

docling/models/vlm_models_inline/vllm_model.py

Lines changed: 23 additions & 20 deletions

```diff
@@ -9,7 +9,12 @@
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -311,24 +316,22 @@ def process_images(
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
-            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
-            generated_tokens = [
-                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
-            ]
+            stop_reason = (
+                VlmStopReason.END_OF_SEQUENCE
+                if output.outputs[0].stop_reason
+                else VlmStopReason.LENGTH
+            )
+            generated_tokens = (
+                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
+                if self.vlm_options.track_generated_tokens
+                else []
+            )
             num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            if self.vlm_options.track_generated_tokens:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                    generated_tokens=generated_tokens,
-                )
-            else:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                )
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+                stop_reason=stop_reason,
+                generated_tokens=generated_tokens,
+            )
```
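
The mapping above leans on vLLM's `CompletionOutput` semantics, where `stop_reason` is populated when generation ended on a stop token or stop string and is `None` when the token budget ran out. A standalone sketch of the same branch (an illustrative helper, not part of the commit):

```python
from typing import Optional

from docling.datamodel.base_models import VlmStopReason


def map_vllm_stop_reason(stop_reason: Optional[str]) -> VlmStopReason:
    """Mirror the branch used above: a populated stop_reason means the model
    stopped on a token or string; None is treated as exhausting the token
    budget. This helper is illustrative only."""
    if stop_reason:
        return VlmStopReason.END_OF_SEQUENCE
    return VlmStopReason.LENGTH


assert map_vllm_stop_reason("</s>") is VlmStopReason.END_OF_SEQUENCE
assert map_vllm_stop_reason(None) is VlmStopReason.LENGTH
```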

docling/pipeline/extraction_vlm_pipeline.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -8,7 +8,7 @@
 
 from docling.backend.abstract_backend import PaginatedDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 from docling.datamodel.extraction import (
     ExtractedPageData,
@@ -83,6 +83,12 @@ def _extract_data(
             # Parse the extracted text as JSON if possible, otherwise use as-is
             extracted_text = predictions[0].text
             extracted_data = None
+            vlm_stop_reason: VlmStopReason = predictions[0].stop_reason
+            if (
+                vlm_stop_reason == VlmStopReason.LENGTH
+                or vlm_stop_reason == VlmStopReason.STOP_SEQUENCE
+            ):
+                ext_res.status = ConversionStatus.PARTIAL_SUCCESS
 
             try:
                 extracted_data = json.loads(extracted_text)
@@ -128,7 +134,11 @@ def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
     def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
         """Determine the status based on extraction results."""
         if ext_res.pages and not any(page.errors for page in ext_res.pages):
-            return ConversionStatus.SUCCESS
+            return (
+                ConversionStatus.PARTIAL_SUCCESS
+                if ext_res.status == ConversionStatus.PARTIAL_SUCCESS
+                else ConversionStatus.SUCCESS
+            )
         else:
             return ConversionStatus.FAILURE
```
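
Taken together, these hunks downgrade a page to partial success when generation was cut short by the token limit or a custom stopping criterion, while page errors still force a failure. A self-contained reduction of that decision logic (this helper is illustrative; the pipeline itself spreads the checks across `_extract_data` and `_determine_status`):

```python
from docling.datamodel.base_models import ConversionStatus, VlmStopReason


def status_for_page(stop_reason: VlmStopReason, has_errors: bool) -> ConversionStatus:
    """Illustrative reduction of the pipeline logic above: errors dominate,
    then truncation or early stopping yields partial success."""
    if has_errors:
        return ConversionStatus.FAILURE
    if stop_reason in (VlmStopReason.LENGTH, VlmStopReason.STOP_SEQUENCE):
        return ConversionStatus.PARTIAL_SUCCESS
    return ConversionStatus.SUCCESS


assert status_for_page(VlmStopReason.END_OF_SEQUENCE, False) is ConversionStatus.SUCCESS
assert status_for_page(VlmStopReason.LENGTH, False) is ConversionStatus.PARTIAL_SUCCESS
```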

docling/utils/api_image_request.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -8,7 +8,7 @@
 from PIL import Image
 from pydantic import AnyUrl
 
-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, VlmStopReason
 from docling.models.utils.generation_utils import GenerationStopper
 
 _log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> Tuple[str, Optional[int]]:
+) -> Tuple[str, Optional[int], VlmStopReason]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -61,7 +61,13 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     num_tokens = api_resp.usage.total_tokens
-    return generated_text, num_tokens
+    stop_reason = (
+        VlmStopReason.LENGTH
+        if api_resp.choices[0].finish_reason == "length"
+        else VlmStopReason.END_OF_SEQUENCE
+    )
+
+    return generated_text, num_tokens, stop_reason
 
 
 def api_image_request_streaming(
```
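
Callers of `api_image_request` now receive a three-value tuple. A minimal usage sketch against a hypothetical OpenAI-compatible endpoint (the URL and image path are placeholders):

```python
from pathlib import Path

from PIL import Image
from pydantic import AnyUrl

from docling.datamodel.base_models import VlmStopReason
from docling.utils.api_image_request import api_image_request

# Placeholder endpoint and input image.
url = AnyUrl("http://localhost:8000/v1/chat/completions")
image = Image.open(Path("page.png"))

text, num_tokens, stop_reason = api_image_request(
    image=image,
    prompt="Convert this page to DocTags.",
    url=url,
    timeout=60,
)

if stop_reason == VlmStopReason.LENGTH:
    print(f"Output truncated after {num_tokens} tokens; consider raising max_tokens.")
```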
