@@ -11,8 +11,8 @@
 from utils import (SftArguments, dataset_map, get_dataset, get_model_tokenizer,
                    get_preprocess)
 
-from swift import (LoraConfig, LoRAConfig, Seq2SeqTrainer,
-                   Seq2SeqTrainingArguments, Swift, get_logger)
+from swift import (LongLoRAConfig, LongLoRAModelType, LoraConfig, LoRAConfig,
+                   Seq2SeqTrainer, Seq2SeqTrainingArguments, Swift, get_logger)
 from swift.utils import (add_version_to_work_dir, broadcast_string,
                          check_json_format, compute_nlg_metrics,
                          data_collate_fn, find_all_linear_for_lora,
@@ -54,27 +54,40 @@ def llm_sft(args: SftArguments) -> None: |
         args.model_type, torch_dtype=args.torch_dtype, **kwargs)
 
     # ### Preparing LoRA
-    if args.sft_type == 'lora':
+    if args.sft_type == 'lora' or args.sft_type == 'longlora':
         if args.resume_from_checkpoint is None:
             if 'ALL' in args.lora_target_modules:
                 assert len(args.lora_target_modules) == 1
                 args.lora_target_modules = find_all_linear_for_lora(
                     model, args.quantization_bit, args.model_type)
                 logger.info(
                     f'Setting lora_target_modules: {args.lora_target_modules}')
-            lora_kwargs = {}
-            if args.tuner_bankend == 'peft':
-                global LoRAConfig
-                LoRAConfig = LoraConfig
-                lora_kwargs['task_type'] = 'CAUSAL_LM'
-            lora_config = LoRAConfig(
-                r=args.lora_rank,
-                target_modules=args.lora_target_modules,
-                lora_alpha=args.lora_alpha,
-                lora_dropout=args.lora_dropout_p,
-                **lora_kwargs)
-            model = Swift.prepare_model(model, lora_config)
-            logger.info(f'lora_config: {lora_config}')
+            if args.sft_type == 'lora':
+                lora_kwargs = {}
+                if args.tuner_bankend == 'peft':
+                    global LoRAConfig
+                    LoRAConfig = LoraConfig
+                    lora_kwargs['task_type'] = 'CAUSAL_LM'
+                lora_config = LoRAConfig(
+                    r=args.lora_rank,
+                    target_modules=args.lora_target_modules,
+                    lora_alpha=args.lora_alpha,
+                    lora_dropout=args.lora_dropout_p,
+                    **lora_kwargs)
+                model = Swift.prepare_model(model, lora_config)
+                logger.info(f'lora_config: {lora_config}')
+            elif args.sft_type == 'longlora':
+                assert args.tuner_bankend != 'peft'
+                assert LongLoRAModelType.LLAMA in args.model_type
+                longlora_config = LongLoRAConfig(
+                    r=args.lora_rank,
+                    target_modules=args.lora_target_modules,
+                    lora_alpha=args.lora_alpha,
+                    lora_dropout=args.lora_dropout_p,
+                    model_type=LongLoRAModelType.LLAMA,
+                    use_flash_attn=args.use_flash_attn)
+                model = Swift.prepare_model(model, longlora_config)
+                logger.info(f'longlora_config: {longlora_config}')
         else:
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
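
For reference outside of llm_sft, the new branch reduces to the following standalone sketch. Only the names taken from the diff are authoritative (LongLoRAConfig, LongLoRAModelType, Swift.prepare_model, and the config fields); the checkpoint, the hyperparameter values, and loading through transformers.AutoModelForCausalLM are illustrative assumptions, not part of this change.

# Minimal sketch of the longlora path above; values are placeholders.
from transformers import AutoModelForCausalLM

from swift import LongLoRAConfig, LongLoRAModelType, Swift

# A LLaMA-family model is assumed (mirrors the assert on LongLoRAModelType.LLAMA).
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')  # placeholder checkpoint
longlora_config = LongLoRAConfig(
    r=8,                                            # args.lora_rank
    target_modules=['q_proj', 'k_proj', 'v_proj'],  # args.lora_target_modules
    lora_alpha=32,                                  # args.lora_alpha
    lora_dropout=0.05,                              # args.lora_dropout_p
    model_type=LongLoRAModelType.LLAMA,
    use_flash_attn=True)                            # args.use_flash_attn
model = Swift.prepare_model(model, longlora_config)

Note that the branch asserts args.tuner_bankend != 'peft': unlike plain LoRA, LongLoRA is only wired up through Swift's own tuner backend.
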
@@ -109,7 +122,10 @@ def llm_sft(args: SftArguments) -> None: |
     if args.test_oom_error:
         train_dataset = sort_by_max_length(train_dataset, 20000)
     # Data analysis
-    data_collator = partial(data_collate_fn, tokenizer=tokenizer)
+    data_collator = partial(
+        data_collate_fn,
+        tokenizer=tokenizer,
+        padding_to=args.max_length if args.sft_type == 'longlora' else None)
     print_example(train_dataset[0], tokenizer)
     stat_dataset(train_dataset)
     stat_dataset(val_dataset)
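
The collator change is the other behavioural difference: with sft_type == 'longlora' every batch is padded to args.max_length, while the LoRA path keeps dynamic per-batch padding (padding_to=None). The helper below is not swift.utils.data_collate_fn; it is only an illustrative stand-in for the padding rule that padding_to selects, and the rationale in its comments (LongLoRA's grouped attention expecting a fixed sequence length) is an assumption rather than something stated in the diff.

# Illustrative stand-in for the padding rule selected by padding_to
# (this is NOT the swift.utils.data_collate_fn implementation).
from typing import List, Optional

import torch


def pad_batch(input_ids_list: List[List[int]],
              pad_token_id: int,
              padding_to: Optional[int] = None) -> dict:
    # padding_to=None -> dynamic padding to the longest sample in the batch (lora path)
    # padding_to=N    -> every batch is padded to the fixed length N (longlora path,
    #                    assumed to suit LongLoRA's grouped/shifted attention)
    target_len = padding_to or max(len(ids) for ids in input_ids_list)
    input_ids = torch.full((len(input_ids_list), target_len),
                           pad_token_id,
                           dtype=torch.long)
    attention_mask = torch.zeros_like(input_ids)
    for i, ids in enumerate(input_ids_list):
        ids = ids[:target_len]  # drop anything longer than the target length
        input_ids[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[i, :len(ids)] = 1
    return {'input_ids': input_ids, 'attention_mask': attention_mask}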