2020import sys
2121import traceback
2222from dataclasses import dataclass , field
23- from typing import Dict , Optional , Sequence , Any
23+ from typing import Any , Dict , Optional , Sequence
2424
2525import numpy as np
2626import paddle
3131from paddlenlp .trainer import PdArgumentParser , TrainingArguments , set_seed
3232from paddlenlp .trainer .trainer import Trainer
3333from paddlenlp .trainer .trainer_utils import get_last_checkpoint
34+ from paddlenlp .transformers .processing_utils import ProcessorMixin
3435from PIL import Image , ImageFile , PngImagePlugin , UnidentifiedImageError
3536
3637from paddlemix .datasets .internvl_dataset import ConcatDataset , WeightedConcatDataset
4243 Qwen2VLImageProcessor ,
4344 Qwen2VLProcessor ,
4445)
45- from paddlenlp .transformers .processing_utils import ProcessorMixin
4646
4747Image .MAX_IMAGE_PIXELS = None
4848ImageFile .LOAD_TRUNCATED_IMAGES = True
@@ -355,7 +355,7 @@ def pure_text_get_item(self, data_item):
355355 attention_mask = attention_mask ,
356356 images = [],
357357 )
358-
358+
359359 return ret
360360
361361 def __getitem__ (self , i ) -> Dict [str , paddle .Tensor ]:
@@ -460,7 +460,7 @@ def __post_init__(self):
460460
461461 def __call__ (self , features : Sequence [Dict [str , Any ]]) -> Dict [str , "paddle.Tensor" ]:
462462 batch_images , batch_videos , batch_imglens , batch_vidlens , batch_input_ids = [], [], [], [], []
463-
463+
464464 for feature in features :
465465 images = feature .pop ("images" , None ) or []
466466 videos = feature .pop ("videos" , None ) or []
@@ -470,17 +470,15 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
470470 batch_vidlens .append (len (videos ))
471471 batch_input_ids .append (feature ["input_ids" ])
472472
473- if (
474- self .processor is not None and sum (batch_imglens ) == 0 and sum (batch_vidlens ) == 0
475- ):
473+ if self .processor is not None and sum (batch_imglens ) == 0 and sum (batch_vidlens ) == 0 :
476474 fake_messages = [{"role" : "user" , "content" : IMAGE_PLACEHOLDER }]
477475 fake_images = [Image .new ("RGB" , (64 , 64 ), (255 , 255 , 255 ))]
478476 fake_messages = self .template .mm_plugin .process_messages (fake_messages , fake_images , [], self .processor )
479477 fake_input_ids = self .tokenizer .encode (fake_messages [0 ]["content" ], add_special_tokens = False )
480478 fake_input_ids , _ = self .template .mm_plugin .process_token_ids (
481479 fake_input_ids , None , fake_images , [], self .tokenizer , self .processor
482480 )
483-
481+
484482 if self .tokenizer .padding_side == "right" :
485483 features [0 ]["input_ids" ] = features [0 ]["input_ids" ] + fake_input_ids
486484 features [0 ]["attention_mask" ] = features [0 ]["attention_mask" ] + [0 ] * len (fake_input_ids )
@@ -530,7 +528,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
530528 return features
531529
532530
533-
534531def main ():
535532 parser = PdArgumentParser ((ModelArguments , DataTrainingArguments , PreTrainingArguments ))
536533 if len (sys .argv ) == 2 and sys .argv [1 ].endswith (".json" ):
@@ -565,6 +562,16 @@ def main():
565562 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
566563 )
567564
565+ if paddle .is_compiled_with_xpu () and training_args .gradient_accumulation_steps > 1 :
566+ try :
567+ from paddle_xpu .layers .nn .linear import LinearConfig # noqa: F401
568+
569+ LinearConfig .enable_accumulate_steps_opt ()
570+ LinearConfig .set_accumulate_steps (training_args .gradient_accumulation_steps )
571+ except ImportError :
572+ # It's OK, not use accumulate_steps optimization
573+ pass
574+
568575 # Load model
569576 if "npu" in paddle .get_device ():
570577 is_bfloat16_supported = True
0 commit comments