@@ -114,20 +114,59 @@ def __init__(
             trust_remote_code=vlm_options.trust_remote_code,
             revision=vlm_options.revision,
         )
-        self.processor.tokenizer.padding_side = "left"
+
+        # Set padding side for tokenizer, handling different processor types
+        if hasattr(self.processor, "tokenizer"):
+            self.processor.tokenizer.padding_side = "left"
+        elif hasattr(self.processor, "_tokenizer"):
+            self.processor._tokenizer.padding_side = "left"
+        else:
+            # Some processors might be different; try to find the tokenizer attribute
+            tokenizer_attrs = ["tokenizer", "_tokenizer", "text_processor"]
+            for attr in tokenizer_attrs:
+                if hasattr(self.processor, attr):
+                    tokenizer = getattr(self.processor, attr)
+                    if hasattr(tokenizer, "padding_side"):
+                        tokenizer.padding_side = "left"
+                    break
+
+        # Determine attention implementation, with fallback for models that don't support certain implementations
+        attn_implementation = "sdpa"
+        if self.device.startswith("cuda") and accelerator_options.cuda_use_flash_attention2:
+            attn_implementation = "flash_attention_2"
+
+        # Override with user-specified attention implementation if provided
+        extra_attn_impl = self.vlm_options.extra_generation_config.get("_attn_implementation")
+        if extra_attn_impl:
+            attn_implementation = extra_attn_impl
+
+        # Handle device_map - it should be passed during model loading, not generation
+        model_loading_kwargs = {
+            "device_map": self.device,  # Use accelerator device as default
+            "dtype": self.vlm_options.torch_dtype,
+            "_attn_implementation": attn_implementation,
+            "trust_remote_code": vlm_options.trust_remote_code,
+            "revision": vlm_options.revision,
+        }
+
+        # Check if user specified a custom device_map in extra_generation_config
+        if "device_map" in self.vlm_options.extra_generation_config:
+            model_loading_kwargs["device_map"] = self.vlm_options.extra_generation_config["device_map"]
+            # Remove it from extra_generation_config to prevent it being passed during generation
+            # We need to create a copy and exclude device_map
+            filtered_extra_config = {
+                k: v for k, v in self.vlm_options.extra_generation_config.items()
+                if k != "device_map"
+            }
+            # Update the vlm_options with the filtered config
+            import copy
+            temp_options = copy.copy(self.vlm_options)
+            temp_options.extra_generation_config = filtered_extra_config
+            self.vlm_options = temp_options
 
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
-            device_map=self.device,
-            dtype=self.vlm_options.torch_dtype,
-            _attn_implementation=(
-                "flash_attention_2"
-                if self.device.startswith("cuda")
-                and accelerator_options.cuda_use_flash_attention2
-                else "sdpa"
-            ),
-            trust_remote_code=vlm_options.trust_remote_code,
-            revision=vlm_options.revision,
+            **model_loading_kwargs,
         )
         self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
@@ -237,6 +276,7 @@ def process_images(
 
         # Use your prompt formatter verbatim
         if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            # For models that don't use prompt styles, pass images directly
             inputs = self.processor(
                 pil_images,
                 return_tensors="pt",
@@ -247,13 +287,48 @@ def process_images(
             prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
 
             # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
-            inputs = self.processor(
-                text=prompts,
-                images=pil_images,
-                return_tensors="pt",
-                padding=True,  # pad across batch for both text and vision
-                **self.vlm_options.extra_processor_kwargs,
-            )
+            # For models that may have specific batch requirements, handle accordingly
+            try:
+                inputs = self.processor(
+                    text=prompts,
+                    images=pil_images,
+                    return_tensors="pt",
+                    padding=True,  # pad across batch for both text and vision
+                    **self.vlm_options.extra_processor_kwargs,
+                )
+            except ValueError as e:
+                if "Received inconsistently sized batches" in str(e):
+                    # Handle models that expect one-to-one image-text pairing
+                    # This is a common case where each image needs its own text
+                    _log.warning(f"Processing with image-text pairing due to: {e}")
+                    # Process each image-text pair separately and combine if needed
+                    all_inputs = []
+                    for img, prompt in zip(pil_images, prompts):
+                        single_input = self.processor(
+                            text=prompt,
+                            images=img,  # Single image for single text
+                            return_tensors="pt",
+                            **self.vlm_options.extra_processor_kwargs,
+                        )
+                        all_inputs.append(single_input)
+
+                    # Combine the inputs - this is a simplified approach
+                    # More complex logic might be needed depending on specific model requirements
+                    if len(all_inputs) == 1:
+                        inputs = all_inputs[0]
+                    else:
+                        # For multiple inputs, we'll use the first one as base and update with batched tensors
+                        inputs = {}
+                        for key in all_inputs[0].keys():
+                            # Stack tensors from each input
+                            stacked_tensors = []
+                            for single_input in all_inputs:
+                                stacked_tensors.append(single_input[key])
+                            # Concatenate along batch dimension (dim=0)
+                            import torch
+                            inputs[key] = torch.cat(stacked_tensors, dim=0)
+                else:
+                    raise
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         # -- Optional stopping criteria
@@ -307,10 +382,14 @@ def process_images(
307382 "clean_up_tokenization_spaces" ,
308383 "spaces_between_special_tokens" ,
309384 }
385+ # Also filter out model loading specific keys that shouldn't be passed to generation
386+ model_loading_keys = {
387+ "_attn_implementation" , # This is for model loading, not generation
388+ }
310389 generation_config = {
311390 k : v
312391 for k , v in self .vlm_options .extra_generation_config .items ()
313- if k not in decoder_keys
392+ if k not in decoder_keys and k not in model_loading_keys
314393 }
315394 decoder_config = {
316395 k : v
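
A minimal usage sketch of the configuration surface these hunks start honoring; only the extra_generation_config field and the device_map / _attn_implementation key names come from the diff above, while the concrete values and the standalone dict are illustrative assumptions, not part of this commit:

    # Sketch only: example values, not defaults from this commit.
    extra_generation_config = {
        "device_map": "auto",             # consumed by from_pretrained(), stripped before generate()
        "_attn_implementation": "eager",  # overrides the sdpa / flash_attention_2 default chosen in __init__
        "max_new_tokens": 4096,           # any remaining keys are still forwarded to generate()
    }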