                    seed_everything, show_layers)
 from .utils import (SftArguments, Template, add_self_cognition_dataset,
                     data_collate_fn, dataset_map, find_all_linear_for_lora,
-                    get_dataset, get_model_tokenizer, get_template,
-                    print_example, sort_by_max_length, stat_dataset)
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, print_example,
+                    set_generation_config, sort_by_max_length, stat_dataset)
 
 logger = get_logger()
 
@@ -182,11 +183,15 @@ def llm_sft(args: SftArguments) -> str:
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
     logger.info(f'generation_config: {generation_config}')
+    set_generation_config(model, generation_config)
     evaluation_strategy = IntervalStrategy.STEPS
     load_best_model_at_end = True
     if val_dataset is None:
         evaluation_strategy = IntervalStrategy.NO
         load_best_model_at_end = False
+    additional_saved_files = []
+    if args.sft_type == 'full':
+        additional_saved_files = get_additional_saved_files(args.model_type)
     training_args = Seq2SeqTrainingArguments(
         output_dir=args.output_dir,
         evaluation_strategy=evaluation_strategy,
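
Note: `set_generation_config` is imported from `.utils` in the hunk above, but its body is not part of this diff. A minimal sketch of the assumed behavior (not the repository's actual implementation): it most likely just attaches the config to the model, since `transformers` uses `model.generation_config` as the default for `generate()`.

```python
from transformers import GenerationConfig, PreTrainedModel

def set_generation_config(model: PreTrainedModel,
                          generation_config: GenerationConfig) -> None:
    # Assumed behavior: store the config on the model so that later
    # model.generate() calls pick it up as the default.
    model.generation_config = generation_config
```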
@@ -230,7 +235,8 @@ def llm_sft(args: SftArguments) -> str:
         only_save_model=args.only_save_model,
         train_sampler_random=args.train_sampler_random,
         report_to=args.report_to,
-        deepspeed=args.deepspeed)
+        deepspeed=args.deepspeed,
+        additional_saved_files=additional_saved_files)
 
     if args.gradient_checkpointing:
         model.enable_input_require_grads()
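
Note: `get_additional_saved_files` likewise comes from `.utils` and is not shown here. The diff only establishes that full-parameter runs (`sft_type == 'full'`) pass its result into `Seq2SeqTrainingArguments` as `additional_saved_files`, presumably so the trainer copies model-specific assets next to each checkpoint. A hypothetical sketch of such a lookup; the mapping below is illustrative, not the repository's actual table:

```python
from typing import Dict, List

# Hypothetical table: some model repos ship extra assets (fonts, custom
# modeling code, ...) that a full checkpoint needs in order to stay loadable.
_EXTRA_FILES: Dict[str, List[str]] = {'qwen-vl': ['SimSun.ttf']}

def get_additional_saved_files(model_type: str) -> List[str]:
    for prefix, files in _EXTRA_FILES.items():
        if model_type.startswith(prefix):
            return files
    return []
```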
|