|
9 | 9 | from PIL.Image import Image |
10 | 10 |
|
11 | 11 | from docling.datamodel.accelerator_options import AcceleratorOptions |
12 | | -from docling.datamodel.base_models import Page, VlmPrediction |
| 12 | +from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken |
13 | 13 | from docling.datamodel.document import ConversionResult |
14 | 14 | from docling.datamodel.pipeline_options_vlm_model import ( |
15 | 15 | InlineVlmOptions, |
@@ -88,7 +88,7 @@ def __init__( |
88 | 88 | vlm_options: InlineVlmOptions, |
89 | 89 | ): |
90 | 90 | self.enabled = enabled |
91 | | - self.vlm_options = vlm_options |
| 91 | + self.vlm_options: InlineVlmOptions = vlm_options |
92 | 92 |
|
93 | 93 | self.llm = None |
94 | 94 | self.sampling_params = None |
@@ -234,7 +234,8 @@ def __call__( |
234 | 234 | pages_with_images.append(page) |
235 | 235 |
|
236 | 236 | if images: |
237 | | - predictions = list(self.process_images(images, user_prompts)) |
| 237 | + with TimeRecorder(conv_res, "vlm_inference"): |
| 238 | + predictions = list(self.process_images(images, user_prompts)) |
238 | 239 | for page, prediction in zip(pages_with_images, predictions): |
239 | 240 | page.predictions.vlm_response = prediction |
240 | 241 |
|
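Side note on the new TimeRecorder context: wrapping process_images should make VLM inference show up in docling's conversion profiling under the "vlm_inference" key. Below is a minimal sketch of reading that timing back; everything outside this diff (the profiling flag, the VlmPipeline wiring, the shape of conv_res.timings) is an assumption about the surrounding docling API, not part of the change.

    # Sketch: surface the "vlm_inference" timing recorded by the new TimeRecorder block.
    # Assumptions: profile_pipeline_timings is the profiling switch, conv_res.timings is
    # a dict keyed by the recorder name, and each entry exposes a list of measured times.
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.settings import settings
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    settings.debug.profile_pipeline_timings = True

    converter = DocumentConverter(
        format_options={InputFormat.PDF: PdfFormatOption(pipeline_cls=VlmPipeline)}
    )
    conv_res = converter.convert("example.pdf")  # hypothetical input file

    timing = conv_res.timings.get("vlm_inference")
    if timing is not None:
        print(f"vlm_inference: {len(timing.times)} call(s), {sum(timing.times):.2f}s total")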
@@ -300,13 +301,34 @@ def process_images( |
300 | 301 | # Optional debug |
301 | 302 | if outputs: |
302 | 303 | try: |
303 | | - num_tokens = len(outputs[0].outputs[0].token_ids) |
304 | | - _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.") |
| 304 | + num_tokens_within_batch = len(outputs[0].outputs[0].token_ids) |
| 305 | + _log.debug( |
| 306 | + f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s." |
| 307 | + ) |
305 | 308 | except Exception: |
306 | | - pass |
| 309 | + num_tokens_within_batch = 0 |
307 | 310 |
|
308 | 311 | # Emit predictions |
309 | 312 | for output in outputs: |
310 | 313 | text = output.outputs[0].text if output.outputs else "" |
| 314 | + stop_reason = output.outputs[0].stop_reason if output.outputs else "" |
| 315 | + generated_tokens = [ |
| 316 | + VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids |
| 317 | + ] if output.outputs else [] |
| 318 | + num_tokens = len(generated_tokens) |
311 | 319 | decoded_text = self.vlm_options.decode_response(text) |
312 | | - yield VlmPrediction(text=decoded_text, generation_time=generation_time) |
| 320 | + if self.vlm_options.track_generated_tokens: |
| 321 | + yield VlmPrediction( |
| 322 | + text=decoded_text, |
| 323 | + generation_time=generation_time, |
| 324 | + num_tokens=num_tokens, |
| 325 | + stop_reason=stop_reason, |
| 326 | + generated_tokens=generated_tokens, |
| 327 | + ) |
| 328 | + else: |
| 329 | + yield VlmPrediction( |
| 330 | + text=decoded_text, |
| 331 | + generation_time=generation_time, |
| 332 | + num_tokens=num_tokens, |
| 333 | + stop_reason=stop_reason, |
| 334 | + ) |
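With the richer VlmPrediction, per-page token counts and stop reasons become available after a VlmPipeline run via page.predictions.vlm_response (set in the hunk above). A rough consumption sketch follows; num_tokens, stop_reason and generated_tokens are the fields added in this diff, while the other attribute paths are assumptions about the surrounding docling models.

    # Sketch: read the new VlmPrediction fields off each converted page.
    # generated_tokens is only populated when vlm_options.track_generated_tokens is set.
    total_tokens = 0
    for page in conv_res.pages:  # conv_res: ConversionResult from the earlier sketch
        pred = page.predictions.vlm_response
        if pred is None:
            continue
        total_tokens += pred.num_tokens
        print(f"page {page.page_no}: {pred.num_tokens} tokens, stop_reason={pred.stop_reason}")
        for tok in pred.generated_tokens[:5]:  # empty list unless token tracking is enabled
            print(f"  token id {tok.token}")
    print(f"total generated tokens: {total_tokens}")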