@@ -336,12 +336,23 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
336336 rejected_type_ids = processed_rejecteds ["token_type_ids" ]
337337 completion_token_type_ids = torch .cat (tuple (pad ([chosen_type_ids , rejected_type_ids ], padding_value = 0 )))
338338 token_type_ids = torch .cat ((prompt_token_type_ids , completion_token_type_ids ), dim = 1 )
339+ if "mm_token_type_ids" in processed_prompts : # special case for Qwen2.5-VL
340+ prompt_mm_token_type_ids = processed_prompts ["mm_token_type_ids" ]
341+ mm_token_type_ids = torch .cat ((prompt_mm_token_type_ids , torch .zeros_like (completion_ids )), dim = 1 )
339342
340343 # Flush left to reduce padding
341- if "token_type_ids" in processed_prompts :
344+ if "token_type_ids" in processed_prompts and "mm_token_type_ids" in processed_prompts :
345+ attention_mask , input_ids , completion_mask , token_type_ids , mm_token_type_ids = flush_left (
346+ attention_mask , input_ids , completion_mask , token_type_ids , mm_token_type_ids
347+ )
348+ elif "token_type_ids" in processed_prompts :
342349 attention_mask , input_ids , completion_mask , token_type_ids = flush_left (
343350 attention_mask , input_ids , completion_mask , token_type_ids
344351 )
352+ elif "mm_token_type_ids" in processed_prompts :
353+ attention_mask , input_ids , completion_mask , mm_token_type_ids = flush_left (
354+ attention_mask , input_ids , completion_mask , mm_token_type_ids
355+ )
345356 else :
346357 attention_mask , input_ids , completion_mask = flush_left (attention_mask , input_ids , completion_mask )
347358
@@ -352,6 +363,8 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
352363 output ["completion_mask" ] = completion_mask
353364 if "token_type_ids" in processed_prompts :
354365 output ["token_type_ids" ] = token_type_ids
366+ if "mm_token_type_ids" in processed_prompts :
367+ output ["mm_token_type_ids" ] = mm_token_type_ids
355368 return output
356369
357370
@@ -992,7 +1005,14 @@ def compute_ref_log_probs(self, inputs):
9921005 shift_completion_mask = completion_mask [..., 1 :].contiguous ()
9931006
9941007 model_kwargs = {"input_ids" : input_ids , "attention_mask" : attention_mask , "use_cache" : False }
995- for key in ("pixel_values" , "pixel_attention_mask" , "image_grid_thw" , "image_sizes" , "token_type_ids" ):
1008+ for key in (
1009+ "pixel_values" ,
1010+ "pixel_attention_mask" ,
1011+ "image_grid_thw" ,
1012+ "image_sizes" ,
1013+ "token_type_ids" ,
1014+ "mm_token_type_ids" ,
1015+ ):
9961016 if key in inputs :
9971017 model_kwargs [key ] = inputs [key ]
9981018
@@ -1113,7 +1133,14 @@ def _compute_loss(self, model, inputs, return_outputs):
11131133 input_ids , attention_mask , completion_mask = self ._truncate_inputs (input_ids , attention_mask , completion_mask )
11141134
11151135 model_kwargs = {"input_ids" : input_ids , "attention_mask" : attention_mask , "use_cache" : False }
1116- for key in ("pixel_values" , "pixel_attention_mask" , "image_grid_thw" , "image_sizes" , "token_type_ids" ):
1136+ for key in (
1137+ "pixel_values" ,
1138+ "pixel_attention_mask" ,
1139+ "image_grid_thw" ,
1140+ "image_sizes" ,
1141+ "token_type_ids" ,
1142+ "mm_token_type_ids" ,
1143+ ):
11171144 if key in inputs :
11181145 model_kwargs [key ] = inputs [key ]
11191146
0 commit comments