 from transformers.utils import is_torch_npu_available
 
 from swift.trainers import Seq2SeqTrainer
+from swift.trainers.utils import can_return_loss, find_labels
 from swift.utils import (check_json_format, compute_acc_metrics,
                          compute_nlg_metrics, get_dist_setting, get_logger,
                          get_main, get_model_info, is_ddp_plus_mp, is_dist,
                          is_master, plot_images, preprocess_logits_for_metrics,
-                         seed_everything, show_layers)
+                         seed_everything, show_layers, use_torchacc)
+from .accelerator import ta_accelerate
 from .tuner import prepare_model
 from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, SftArguments, Template,
                     add_self_cognition_dataset, dataset_map, get_dataset,
@@ -55,15 +57,15 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     model_kwargs = {'low_cpu_mem_usage': True}
     if is_dist() and not is_ddp_plus_mp():
         model_kwargs['device_map'] = {'': local_rank}
-    else:
+    elif not use_torchacc():
         model_kwargs['device_map'] = 'auto'
+
     if args.load_in_8bit or args.load_in_4bit:
         quantization_config = BitsAndBytesConfig(
             args.load_in_8bit,
             args.load_in_4bit,
             bnb_4bit_compute_dtype=args.bnb_4bit_compute_dtype,
             bnb_4bit_quant_type=args.bnb_4bit_quant_type,
-            bnb_4bit_quant_storage=args.bnb_4bit_quant_storage,
             bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant)
         logger.info(f'quantization_config: {quantization_config.__dict__}')
         model_kwargs['quantization_config'] = quantization_config
@@ -93,6 +95,13 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     set_generation_config(model, generation_config)
     training_args.generation_config = generation_config
 
+    if use_torchacc():
+        import torchacc as ta
+        # Get `label_names` and `return_loss` before `ta_accelerate`, because
+        # wrapping the model would make these properties wrong.
+        label_names = find_labels(model)
+        return_loss = can_return_loss(model)
+        model = ta.patch_qwen_model(model)
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)
 
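The hunk above captures `label_names` and `return_loss` before the model gets wrapped, since those helpers inspect the model's `forward` signature, which a wrapper no longer exposes directly. A minimal illustrative sketch of what such helpers typically do (not the actual `swift.trainers.utils` code, which may differ):

import inspect

def find_labels_sketch(model):
    # Treat forward() parameters whose name contains 'label' as label inputs.
    signature = inspect.signature(model.forward)
    return [name for name in signature.parameters if 'label' in name]

def can_return_loss_sketch(model):
    # True if forward() accepts a `return_loss` argument.
    signature = inspect.signature(model.forward)
    return 'return_loss' in signature.parameters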
@@ -108,6 +117,18 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         logger.info('Setting model.config.use_cache: False')
         model.enable_input_require_grads()
 
+    if use_torchacc():
+        model.config.use_cache = False
+        logger.info('Setting model.config.use_cache: False')
+        model = ta_accelerate(
+            model,
+            world_size,
+            args.model_layer_cls_name,
+            args.bf16,
+            args.fp16,
+            gradient_checkpointing=True,
+            fsdp_flatten_parameters=False)
+
     # Loading Dataset
     random_state = np.random.RandomState(args.dataset_seed)
     train_dataset, val_dataset = get_dataset(
@@ -185,6 +206,15 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     padding_to = args.max_length if args.sft_type == 'longlora' else None
     data_collator = partial(template.data_collator, padding_to=padding_to)
 
+    train_batch_size = args.batch_size
+    eval_batch_size = args.eval_batch_size
+    if use_torchacc():
+        train_batch_size *= world_size
+        eval_batch_size *= world_size
+        training_args.per_device_train_batch_size = train_batch_size
+        training_args.per_device_eval_batch_size = eval_batch_size
+        training_args.group_by_length = use_torchacc()
+
     # Trainer
     logger.info(f'training_args: {training_args}')
 
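The hunk above multiplies the per-device batch sizes by `world_size` when TorchAcc is active, presumably because a single process drives all accelerator cores and the HF Trainer therefore sees only one logical device; `group_by_length` is also switched on, which groups samples of similar length and reduces padding. A small illustration of the resulting numbers (all values below are hypothetical):

# Hypothetical values, only to illustrate the scaling performed above.
batch_size = 2                                           # args.batch_size per core
world_size = 8                                           # accelerator cores (assumption)
per_device_train_batch_size = batch_size * world_size    # 16 handed to the Trainer
samples_per_step = per_device_train_batch_size           # 16, ignoring gradient accumulation
print(samples_per_step)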
@@ -211,6 +241,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         callbacks=callbacks,
         **trainer_kwargs)
     trainer.sft_args = args
+    if use_torchacc():
+        trainer.label_names = label_names
+        trainer.can_return_loss = return_loss
     if is_master():
         for args_obj, fname in zip([args, training_args],
                                    ['sft_args.json', 'training_args.json']):
@@ -233,7 +266,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
     train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
-    if is_master():
+    if is_master() and not use_torchacc():
         images_dir = os.path.join(args.output_dir, 'images')
         logger.info(f'images_dir: {images_dir}')
         plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
@@ -253,4 +286,14 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     }
 
 
-sft_main = get_main(SftArguments, llm_sft)
+def get_sft_main(args, llm):
+    if use_torchacc():
+        logger.warning('TorchAcc is currently only available internally '
+                       'within Alibaba Cloud.')
+        import torchacc as ta
+        # This patch should be called before `llm_sft`.
+        ta.accelerate_hf_trainer()
+    return get_main(args, llm)
+
+
+sft_main = get_sft_main(SftArguments, llm_sft)
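`get_sft_main` applies the TorchAcc trainer patch before `get_main` builds and runs `llm_sft`, so the patched Trainer is in place for the whole run. The `use_torchacc()` switch itself comes from `swift.utils`; a plausible sketch, assuming it is driven by an environment variable (the variable name and default below are assumptions):

import os

def use_torchacc_sketch() -> bool:
    # Assumed switch, e.g. launched as: USE_TORCHACC=1 swift sft ...
    return os.getenv('USE_TORCHACC', '0') == '1'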