
Commit 493180c

feat(models): enhance HuggingFaceTransformersVlmModel with improved handling
1 parent a5af082 commit 493180c

1 file changed: +98 −19 lines

docling/models/vlm_models_inline/hf_transformers_model.py

Lines changed: 98 additions & 19 deletions
@@ -114,20 +114,59 @@ def __init__(
             trust_remote_code=vlm_options.trust_remote_code,
             revision=vlm_options.revision,
         )
-        self.processor.tokenizer.padding_side = "left"
+
+        # Set padding side for tokenizer, handling different processor types
+        if hasattr(self.processor, "tokenizer"):
+            self.processor.tokenizer.padding_side = "left"
+        elif hasattr(self.processor, "_tokenizer"):
+            self.processor._tokenizer.padding_side = "left"
+        else:
+            # Some processors might be different; try to find the tokenizer attribute
+            tokenizer_attrs = ["tokenizer", "_tokenizer", "text_processor"]
+            for attr in tokenizer_attrs:
+                if hasattr(self.processor, attr):
+                    tokenizer = getattr(self.processor, attr)
+                    if hasattr(tokenizer, "padding_side"):
+                        tokenizer.padding_side = "left"
+                        break
+
+        # Determine attention implementation, with fallback for models that don't support certain implementations
+        attn_implementation = "sdpa"
+        if self.device.startswith("cuda") and accelerator_options.cuda_use_flash_attention2:
+            attn_implementation = "flash_attention_2"
+
+        # Override with user-specified attention implementation if provided
+        extra_attn_impl = self.vlm_options.extra_generation_config.get("_attn_implementation")
+        if extra_attn_impl:
+            attn_implementation = extra_attn_impl
+
+        # Handle device_map - it should be passed during model loading, not generation
+        model_loading_kwargs = {
+            "device_map": self.device,  # Use accelerator device as default
+            "dtype": self.vlm_options.torch_dtype,
+            "_attn_implementation": attn_implementation,
+            "trust_remote_code": vlm_options.trust_remote_code,
+            "revision": vlm_options.revision,
+        }
+
+        # Check if user specified a custom device_map in extra_generation_config
+        if "device_map" in self.vlm_options.extra_generation_config:
+            model_loading_kwargs["device_map"] = self.vlm_options.extra_generation_config["device_map"]
+            # Remove it from extra_generation_config to prevent it being passed during generation
+            # We need to create a copy and exclude device_map
+            filtered_extra_config = {
+                k: v for k, v in self.vlm_options.extra_generation_config.items()
+                if k != "device_map"
+            }
+            # Update the vlm_options with the filtered config
+            import copy
+            temp_options = copy.copy(self.vlm_options)
+            temp_options.extra_generation_config = filtered_extra_config
+            self.vlm_options = temp_options
 
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
-            device_map=self.device,
-            dtype=self.vlm_options.torch_dtype,
-            _attn_implementation=(
-                "flash_attention_2"
-                if self.device.startswith("cuda")
-                and accelerator_options.cuda_use_flash_attention2
-                else "sdpa"
-            ),
-            trust_remote_code=vlm_options.trust_remote_code,
-            revision=vlm_options.revision,
+            **model_loading_kwargs,
         )
         self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
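The loading path in this hunk resolves two overrides from vlm_options.extra_generation_config: a custom device_map (consumed at load time and stripped from the options copy) and a custom _attn_implementation (used to build model_loading_kwargs; it is filtered out later, just before generation, in the last hunk below). A minimal standalone sketch of that resolution logic, with illustrative values that are not docling defaults, and with both filters collapsed into one step for brevity:

# Illustrative sketch only; the dict values are made up for the example.
extra_generation_config = {
    "device_map": "auto",              # load-time override
    "_attn_implementation": "eager",   # overrides the sdpa / flash_attention_2 default
    "max_new_tokens": 512,             # ordinary generation kwarg
}

# Same precedence as the diff: user override, else flash_attention_2 on CUDA
# when enabled, else sdpa (CUDA check omitted here).
attn_implementation = extra_generation_config.get("_attn_implementation") or "sdpa"

model_loading_kwargs = {
    "device_map": extra_generation_config.get("device_map", "cpu"),
    "_attn_implementation": attn_implementation,
}

# device_map must not leak into generate(); the commit strips it from a copy
# of the options, which this one-liner mimics.
filtered_config = {k: v for k, v in extra_generation_config.items() if k != "device_map"}

print(model_loading_kwargs)  # {'device_map': 'auto', '_attn_implementation': 'eager'}
print(filtered_config)       # {'_attn_implementation': 'eager', 'max_new_tokens': 512}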
@@ -237,6 +276,7 @@ def process_images(
 
         # Use your prompt formatter verbatim
         if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            # For models that don't use prompt styles, pass images directly
             inputs = self.processor(
                 pil_images,
                 return_tensors="pt",
@@ -247,13 +287,48 @@ def process_images(
             prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
 
             # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
-            inputs = self.processor(
-                text=prompts,
-                images=pil_images,
-                return_tensors="pt",
-                padding=True,  # pad across batch for both text and vision
-                **self.vlm_options.extra_processor_kwargs,
-            )
+            # For models that may have specific batch requirements, handle accordingly
+            try:
+                inputs = self.processor(
+                    text=prompts,
+                    images=pil_images,
+                    return_tensors="pt",
+                    padding=True,  # pad across batch for both text and vision
+                    **self.vlm_options.extra_processor_kwargs,
+                )
+            except ValueError as e:
+                if "Received inconsistently sized batches" in str(e):
+                    # Handle models that expect one-to-one image-text pairing
+                    # This is a common case where each image needs its own text
+                    _log.warning(f"Processing with image-text pairing due to: {e}")
+                    # Process each image-text pair separately and combine if needed
+                    all_inputs = []
+                    for img, prompt in zip(pil_images, prompts):
+                        single_input = self.processor(
+                            text=prompt,
+                            images=img,  # Single image for single text
+                            return_tensors="pt",
+                            **self.vlm_options.extra_processor_kwargs,
+                        )
+                        all_inputs.append(single_input)
+
+                    # Combine the inputs - this is a simplified approach
+                    # More complex logic might be needed depending on specific model requirements
+                    if len(all_inputs) == 1:
+                        inputs = all_inputs[0]
+                    else:
+                        # For multiple inputs, we'll use the first one as base and update with batched tensors
+                        inputs = {}
+                        for key in all_inputs[0].keys():
+                            # Stack tensors from each input
+                            stacked_tensors = []
+                            for single_input in all_inputs:
+                                stacked_tensors.append(single_input[key])
+                            # Concatenate along batch dimension (dim=0)
+                            import torch
+                            inputs[key] = torch.cat(stacked_tensors, dim=0)
+                else:
+                    raise
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         # -- Optional stopping criteria
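The fallback branch above batches the per-pair processor outputs by concatenating each tensor along dim=0. torch.cat requires every non-batch dimension to match, so this only works when the individually processed pairs come back with identical sequence lengths (or when there is a single pair). A self-contained toy example of the concatenation step, with made-up tensor shapes:

import torch

# Two per-sample outputs, as the fallback loop would collect them.
single_inputs = [
    {"input_ids": torch.tensor([[1, 2, 3]]), "pixel_values": torch.rand(1, 3, 8, 8)},
    {"input_ids": torch.tensor([[4, 5, 6]]), "pixel_values": torch.rand(1, 3, 8, 8)},
]

# Concatenate every key along the batch dimension, as in the diff.
batched = {
    key: torch.cat([sample[key] for sample in single_inputs], dim=0)
    for key in single_inputs[0]
}

print(batched["input_ids"].shape)     # torch.Size([2, 3])
print(batched["pixel_values"].shape)  # torch.Size([2, 3, 8, 8])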
@@ -307,10 +382,14 @@ def process_images(
             "clean_up_tokenization_spaces",
             "spaces_between_special_tokens",
         }
+        # Also filter out model loading specific keys that shouldn't be passed to generation
+        model_loading_keys = {
+            "_attn_implementation",  # This is for model loading, not generation
+        }
         generation_config = {
             k: v
             for k, v in self.vlm_options.extra_generation_config.items()
-            if k not in decoder_keys
+            if k not in decoder_keys and k not in model_loading_keys
         }
         decoder_config = {
             k: v
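With model_loading_keys in place, load-time-only entries left in extra_generation_config no longer reach model.generate(). A small illustrative example of the filtering, using made-up config values and only the two decoder keys visible in the context above:

# Illustrative only; values are examples, not docling defaults.
extra_generation_config = {
    "_attn_implementation": "eager",       # load-time key, now excluded
    "clean_up_tokenization_spaces": True,  # decoder key, excluded as before
    "max_new_tokens": 256,                 # genuine generation kwarg, kept
}
decoder_keys = {"clean_up_tokenization_spaces", "spaces_between_special_tokens"}
model_loading_keys = {"_attn_implementation"}

generation_config = {
    k: v
    for k, v in extra_generation_config.items()
    if k not in decoder_keys and k not in model_loading_keys
}
print(generation_config)  # {'max_new_tokens': 256}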
