@@ -19,6 +19,7 @@
 
 
 def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
+    logger.info(f'replace_if_exists: {replace_if_exists}')
     assert args.ckpt_dir is not None
     assert args.sft_type == 'lora'
     assert 'int4' not in args.model_type, 'int4 model is not supported'
@@ -65,10 +66,21 @@ def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
         res.pop('adapter_cfg', None)
         with open(new_configuration_path, 'w') as f:
             json.dump(res, f, ensure_ascii=False, indent=4)
-        logger.info('Successfully merged LoRA.')
+        # sft_args
+        sft_args_fname = 'sft_args.json'
+        old_sft_args_path = os.path.join(old_ckpt_dir, sft_args_fname)
+        new_sft_args_path = os.path.join(args.ckpt_dir, sft_args_fname)
+        if os.path.exists(old_sft_args_path):
+            with open(old_sft_args_path, 'r') as f:
+                res = json.load(f)
+            res['sft_type'] = 'full'
+            with open(new_sft_args_path, 'w') as f:
+                json.dump(res, f, ensure_ascii=False, indent=2)
+        logger.info(f'Successfully merged LoRA and saved in {args.ckpt_dir}.')
     else:
-        logger.info('The weight directory for the merged LoRA already exists, '
-                    'skipping the saving process.')
+        logger.info(
+            f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
+            'skipping the saving process.')
 
 
 def prepare_model_template(
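For context, a minimal usage sketch of the patched function follows. The import path, the `InferArguments` fields shown, and the checkpoint path are assumptions for illustration, not part of the diff:

    # Hypothetical sketch: merge the LoRA adapter found in ckpt_dir into the
    # base model. With this patch, the merged checkpoint directory also gets a
    # copy of sft_args.json with sft_type rewritten to 'full', so the merged
    # checkpoint is treated as a full fine-tune downstream.
    from swift.llm import InferArguments, merge_lora  # assumed import path

    args = InferArguments(ckpt_dir='output/model/vx-xxx/checkpoint-xxx')  # placeholder path
    merge_lora(args)                          # skips saving if the merged dir already exists
    merge_lora(args, replace_if_exists=True)  # re-runs the merge and overwrites the output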