@@ -9,22 +9,19 @@
 from modelscope import BitsAndBytesConfig, GenerationConfig
 
 from swift.trainers import (IntervalStrategy, Seq2SeqTrainer,
-                            Seq2SeqTrainingArguments)
-from swift.tuners import (LongLoRAConfig, LongLoRAModelType, LoraConfig,
-                          LoRAConfig, NEFTuneConfig, Swift)
+                            Seq2SeqTrainingArguments, TrainerCallback)
 from swift.utils import (check_json_format, compute_acc_metrics,
-                         compute_nlg_metrics, freeze_model_parameters,
-                         get_dist_setting, get_logger, get_main,
-                         get_model_info, is_ddp_plus_mp, is_dist, is_master,
-                         plot_images, preprocess_logits_for_metrics,
+                         compute_nlg_metrics, get_dist_setting, get_logger,
+                         get_main, get_model_info, is_ddp_plus_mp, is_dist,
+                         is_master, plot_images, preprocess_logits_for_metrics,
                          seed_everything, show_layers)
 from .tuner import prepare_model
 from .utils import (LazyLLMDataset, SftArguments, Template,
                     add_self_cognition_dataset, data_collate_fn, dataset_map,
-                    find_all_linear_for_lora, get_additional_saved_files,
-                    get_dataset, get_model_tokenizer, get_template,
-                    get_time_info, print_example, set_generation_config,
-                    sort_by_max_length, stat_dataset)
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, get_time_info,
+                    print_example, set_generation_config, sort_by_max_length,
+                    stat_dataset)
 
 logger = get_logger()
 
@@ -234,13 +231,19 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: |
     if args.check_model_is_latest is False:
         trainer_kwargs['check_model'] = False
 
+    class TrainerAdapterCallback(TrainerCallback):
+
+        def on_train_begin(*args, **kwargs):
+            model.set_active_adapters(model.adapters.keys(), offload='meta')
+
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset,
         eval_dataset=val_dataset,
         tokenizer=tokenizer,
+        callbacks=[TrainerAdapterCallback()],
         **trainer_kwargs)
     trainer.sft_args = args
     if is_master():
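
For context, the callback added in this commit hooks into the Hugging Face Trainer lifecycle so that every registered adapter is activated before the first training step. Below is a minimal, illustrative sketch of the same pattern, assuming only the public transformers TrainerCallback API; the class name here is hypothetical, and set_active_adapters / adapters are attributes of swift's wrapped model (taken from the diff above), not part of transformers.

from transformers import TrainerCallback


class AdapterActivationCallback(TrainerCallback):
    """Illustrative sketch: activate all registered adapters at train start."""

    def __init__(self, model):
        self.model = model

    def on_train_begin(self, args, state, control, **kwargs):
        # `set_active_adapters` / `adapters` come from swift's wrapped model,
        # as used in the diff above; offload='meta' keeps inactive adapter
        # weights off the compute device.
        self.model.set_active_adapters(
            self.model.adapters.keys(), offload='meta')


# Usage mirrors the diff: register the callback via the trainer's
# `callbacks` argument, e.g.
# trainer = Seq2SeqTrainer(..., callbacks=[AdapterActivationCallback(model)])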