Commit b6c892b

Authored by Peter El Hachem (ElHachem02) and co-authors
feat(vlm): add num_tokens as attribute for VlmPrediction (#2489)
* feat: add num_tokens as attribute for VlmPrediction
* feat: implement tokens tracking for api_vlm
  Signed-off-by: Peter El Hachem <[email protected]>
* DCO Remediation Commit for ElHachem02 <[email protected]>
  I, ElHachem02 <[email protected]>, hereby add my Signed-off-by to this commit: 311287f
  Signed-off-by: Peter El Hachem <[email protected]>
* DCO Remediation Commit for ElHachem02 <[email protected]>
  I, ElHachem02 <[email protected]>, hereby add my Signed-off-by to this commit: 311287f
  Signed-off-by: ElHachem02 <[email protected]>
* update return type
  Signed-off-by: ElHachem02 <[email protected]>
* add time recorder for vlm inference and track generated token ids depending on config
  Signed-off-by: ElHachem02 <[email protected]>
* update num_tokens to have None as value on exception
  Signed-off-by: ElHachem02 <[email protected]>
* set default value of num_tokens to None
  Signed-off-by: ElHachem02 <[email protected]>

---------

Signed-off-by: Peter El Hachem <[email protected]>
Signed-off-by: ElHachem02 <[email protected]>
Signed-off-by: peets <[email protected]>
Co-authored-by: Peter El Hachem <[email protected]>
1 parent cdffb47 commit b6c892b

File tree: 8 files changed (+71, -20 lines)
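For context on what this change enables downstream, here is a hedged end-to-end sketch of reading the new token count off a page's VLM prediction after a conversion. The pipeline wiring follows docling's documented VLM pipeline pattern; the input file is a placeholder and the default pipeline options are assumed, not taken from this commit.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Convert a document with the VLM pipeline (default options assumed).
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=VlmPipelineOptions(),
        )
    }
)
result = converter.convert("sample.pdf")  # placeholder input document

for page in result.pages:
    pred = page.predictions.vlm_response
    if pred is not None:
        # num_tokens and stop_reason are the fields added in this commit; both may be None.
        print(page.page_no, pred.num_tokens, pred.stop_reason)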

docling/datamodel/base_models.py
Lines changed: 2 additions & 0 deletions

@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: Optional[int] = None
+    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
 
 
 class ContainerElement(
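As a quick illustration of the new data model defaults (a sketch, not part of the diff): both fields are optional and default to None, so existing call sites that construct VlmPrediction without them keep working unchanged.

from docling.datamodel.base_models import VlmPrediction

pred = VlmPrediction(text="<doctag>...</doctag>", generation_time=1.7)
assert pred.num_tokens is None   # stays None when a backend reports no token count
assert pred.stop_reason is None

pred = VlmPrediction(
    text="...", generation_time=1.7, num_tokens=412, stop_reason="stop"
)
print(pred.num_tokens)  # 412; "stop" is an illustrative value, no enum is defined yet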

docling/datamodel/pipeline_options_vlm_model.py
Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):
 
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    track_generated_tokens: bool = False
 
     @property
     def repo_cache_folder(self) -> str:
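A hedged sketch of opting into per-token tracking via the new flag. The preset below is assumed to come from docling's bundled model specs (the name is not part of this diff); model_copy() is the standard pydantic v2 copy method.

from docling.datamodel import vlm_model_specs

# Copy a bundled InlineVlmOptions preset (preset name assumed) and enable tracking.
opts = vlm_model_specs.SMOLDOCLING_TRANSFORMERS.model_copy()
opts.track_generated_tokens = True  # new flag added in this commit; defaults to False

In this commit the flag is only consulted by the vLLM backend further below, where it controls whether the full generated_tokens list is attached to each VlmPrediction in addition to the aggregate num_tokens.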

docling/models/api_vlm_model.py
Lines changed: 5 additions & 3 deletions

@@ -73,7 +73,7 @@ def _vlm_request(page):
             # Skip non-GenerationStopper criteria (should have been caught in validation)
 
             # Streaming path with early abort support
-            page_tags = api_image_request_streaming(
+            page_tags, num_tokens = api_image_request_streaming(
                 image=hi_res_image,
                 prompt=prompt,
                 url=self.vlm_options.url,
@@ -84,7 +84,7 @@ def _vlm_request(page):
             )
         else:
             # Non-streaming fallback (existing behavior)
-            page_tags = api_image_request(
+            page_tags, num_tokens = api_image_request(
                 image=hi_res_image,
                 prompt=prompt,
                 url=self.vlm_options.url,
@@ -94,7 +94,9 @@ def _vlm_request(page):
             )
 
         page_tags = self.vlm_options.decode_response(page_tags)
-        page.predictions.vlm_response = VlmPrediction(text=page_tags)
+        page.predictions.vlm_response = VlmPrediction(
+            text=page_tags, num_tokens=num_tokens
+        )
         return page
 
     with ThreadPoolExecutor(max_workers=self.concurrency) as executor:

docling/models/vlm_models_inline/hf_transformers_model.py
Lines changed: 8 additions & 2 deletions

@@ -367,13 +367,19 @@ def process_images(
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )

docling/models/vlm_models_inline/mlx_model.py
Lines changed: 1 addition & 0 deletions

@@ -318,5 +318,6 @@ def process_images(
                     text=decoded_output,
                     generation_time=generation_time,
                     generated_tokens=tokens,
+                    num_tokens=len(tokens),
                 )
                 _log.debug("MLX model: Released global lock")

docling/models/vlm_models_inline/nuextract_transformers_model.py
Lines changed: 8 additions & 2 deletions

@@ -282,13 +282,19 @@ def process_images(
         )
 
         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )

docling/models/vlm_models_inline/vllm_model.py
Lines changed: 29 additions & 7 deletions

@@ -9,7 +9,7 @@
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -88,7 +88,7 @@ def __init__(
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options
 
         self.llm = None
         self.sampling_params = None
@@ -234,7 +234,8 @@ def __call__(
                 pages_with_images.append(page)
 
         if images:
-            predictions = list(self.process_images(images, user_prompts))
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))
             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction
 
@@ -300,13 +301,34 @@ def process_images(
         # Optional debug
         if outputs:
             try:
-                num_tokens = len(outputs[0].outputs[0].token_ids)
-                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-                pass
+                num_tokens_within_batch = 0
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )

docling/utils/api_image_request.py
Lines changed: 17 additions & 6 deletions

@@ -2,7 +2,7 @@
 import json
 import logging
 from io import BytesIO
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 
 import requests
 from PIL import Image
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(
 
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens
 
 
 def api_image_request_streaming(
@@ -72,7 +73,7 @@ def api_image_request_streaming(
     headers: Optional[dict[str, str]] = None,
     generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
                 _log.debug("Unexpected SSE chunk shape: %s", e)
                 piece = ""
 
+            # Try to extract token count
+            num_tokens = None
+            try:
+                if "usage" in obj:
+                    usage = obj["usage"]
+                    num_tokens = usage.get("total_tokens")
+            except Exception as e:
+                num_tokens = None
+                _log.debug("Usage key not included in response: %s", e)
+
             if piece:
                 full_text.append(piece)
                 for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
                     # closing the connection when we exit the 'with' block.
                     # vLLM/OpenAI-compatible servers will detect the client disconnect
                     # and abort the request server-side.
-                    return "".join(full_text)
+                    return "".join(full_text), num_tokens
 
-    return "".join(full_text)
+    return "".join(full_text), num_tokens
