2020import sys
2121import traceback
2222from dataclasses import dataclass , field
23- from typing import Dict , Optional
23+ from typing import Dict , Optional , Sequence , Any
2424
2525import numpy as np
2626import paddle
4242 Qwen2VLImageProcessor ,
4343 Qwen2VLProcessor ,
4444)
45+ from paddlenlp .transformers .processing_utils import ProcessorMixin
4546
4647Image .MAX_IMAGE_PIXELS = None
4748ImageFile .LOAD_TRUNCATED_IMAGES = True
4849MaximumDecompressedSize = 1024
4950MegaByte = 2 ** 20
5051PngImagePlugin .MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
51-
5252logger = logging .getLogger (__name__ )
5353
5454
@@ -303,8 +303,6 @@ def multi_modal_get_item(self, data_item):
303303
304304 # Merge the image path
305305 image_path = self .get_image_path (data_item ["images" ][0 ]) # TODO: now only single image
306- image = self .load_image (image_path )
307- image_data_dict = transform (image )
308306
309307 messages = data_item ["messages" ]
310308
@@ -328,19 +326,11 @@ def multi_modal_get_item(self, data_item):
328326 input_ids = input_ids ,
329327 labels = labels ,
330328 attention_mask = attention_mask ,
331- pixel_values = image_data_dict ["pixel_values" ],
332- image_grid_thw = image_data_dict ["image_grid_thw" ][0 ],
329+ images = [image_path ],
333330 )
334331 return ret
335332
336333 def pure_text_get_item (self , data_item ):
337- # Build transformation function
338- transform = self .get_transform ()
339-
340- # Create a blank white image
341- image = Image .new ("RGB" , (224 , 224 ), (255 , 255 , 255 ))
342- image_data_dict = transform (image )
343-
344334 messages = data_item ["messages" ]
345335
346336 input_ids , labels = _encode_supervised_example (
@@ -363,9 +353,9 @@ def pure_text_get_item(self, data_item):
363353 input_ids = input_ids ,
364354 labels = labels ,
365355 attention_mask = attention_mask ,
366- pixel_values = image_data_dict ["pixel_values" ],
367- image_grid_thw = image_data_dict ["image_grid_thw" ][0 ],
356+ images = [],
368357 )
358+
369359 return ret
370360
371361 def __getitem__ (self , i ) -> Dict [str , paddle .Tensor ]:
@@ -374,10 +364,6 @@ def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
374364 try :
375365 data_item = self .raw_data [i ]
376366 if "images" in data_item and len (data_item ["images" ]) != 0 :
377- # if type(data_item['images']) == list:
378- # ret = self.multi_modal_multi_image_get_item(data_item)
379- # else:
380- # ret = self.multi_modal_get_item(data_item)
381367 ret = self .multi_modal_get_item (data_item ) # TODO: 暂时都是单图
382368 else :
383369 ret = self .pure_text_get_item (data_item ) # TODO: 纯文
@@ -457,6 +443,94 @@ def print_trainable_params(model: paddle.nn.Layer) -> None:
457443 )
458444
459445
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
    """

    # Conversation template; its ``mm_plugin`` turns raw images/videos into model inputs.
    template: Optional["TEMPLATES"] = None
    # Multimodal processor handed through to the template's mm_plugin.
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        # The template is mandatory: __call__ dereferences self.template.mm_plugin.
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
        # Collect per-sample media and token ids. images/videos are popped so the
        # parent (text-only) collator never sees them.
        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []

        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_input_ids.append(feature["input_ids"])

        if (
            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
        ):
            # Fully text-only batch: inject one fake (blank) image into the first sample
            # so the visual branch still receives input (keeps sharded training in step).
            # The fake tokens are masked out via attention_mask=0 / labels=IGNORE_INDEX.
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
            )

            if self.tokenizer.padding_side == "right":
                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
                features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
            else:
                features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

            batch_images = fake_images
            batch_imglens[0] = 1
            batch_input_ids[0] = features[0]["input_ids"]

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
        )
        if "token_type_ids" in mm_inputs:
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        # Delegate text padding/tensorization to the parent collator.
        features: Dict[str, "paddle.Tensor"] = super().__call__(features)

        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
                input_ids=features["input_ids"],
                image_grid_thw=mm_inputs.get("image_grid_thw", None),
                video_grid_thw=mm_inputs.get("video_grid_thw", None),
                attention_mask=features["attention_mask"],
            )

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            # BUGFIX: the previous code used the torch-only idioms ``Tensor.size(1)``
            # (Paddle's ``Tensor.size`` is an int attribute, not callable) and ``F.pad``
            # (``F`` is never imported here, and torch's pad-tuple ordering differs from
            # paddle's). Pad the sequence axis (dim 1) with zeros using paddle ops,
            # rank-agnostically.
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            seq_len = features["input_ids"].shape[1]
            orig_len = cross_attention_mask.shape[1]
            if seq_len > orig_len:
                pad_shape = list(cross_attention_mask.shape)
                pad_shape[1] = seq_len - orig_len
                cross_attention_mask = paddle.concat(
                    [cross_attention_mask, paddle.zeros(pad_shape, dtype=cross_attention_mask.dtype)], axis=1
                )
            mm_inputs["cross_attention_mask"] = cross_attention_mask

        features.update(mm_inputs)
        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
            features = features.data  # use default_collate() instead of BatchEncoding.to()

        if "image_bound" in features:  # for minicpmv inputs
            # BUGFIX: ``paddle.arange(...).long().repeat(bsz, 1)`` mixes torch APIs into
            # paddle code — paddle tensors have neither ``.long()`` nor torch-style
            # ``.repeat``; use int64 arange + ``tile`` to build (bsz, seq_length) ids.
            bsz, seq_length = features["input_ids"].shape
            features["position_ids"] = paddle.arange(seq_length, dtype="int64").unsqueeze(0).tile([bsz, 1])
            return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}

        return features
531+
532+
533+
460534def main ():
461535 parser = PdArgumentParser ((ModelArguments , DataTrainingArguments , PreTrainingArguments ))
462536 if len (sys .argv ) == 2 and sys .argv [1 ].endswith (".json" ):
@@ -516,7 +590,7 @@ def main():
516590 MODEL_NAME = model_args .model_name_or_path
517591 model = Qwen2VLForConditionalGeneration .from_pretrained (MODEL_NAME , dtype = dtype )
518592 image_processor = Qwen2VLImageProcessor .from_pretrained (MODEL_NAME )
519- tokenizer = MIXQwen2Tokenizer .from_pretrained (MODEL_NAME )
593+ tokenizer = MIXQwen2Tokenizer .from_pretrained (MODEL_NAME , padding_side = "right" )
520594 processor = Qwen2VLProcessor (image_processor , tokenizer )
521595
522596 tokenizer .tokenizer_path = tokenizer_path
@@ -578,8 +652,10 @@ def _freeze_params(module):
578652 # set seed for paddle dataloaders
579653 set_seed (training_args .seed )
580654
581- data_collator = DataCollatorForSeq2Seq (
655+ data_collator = MultiModalDataCollatorForSeq2Seq (
582656 tokenizer = tokenizer ,
657+ template = TEMPLATES [data_args .conv_style ],
658+ processor = processor ,
583659 pad_to_multiple_of = 8 if training_args .do_train else None , # for shift short attention
584660 label_pad_token_id = IGNORE_INDEX ,
585661 )
0 commit comments