Skip to content

Commit 4dea862

Browse files
authored
Add cli merge lora (#168)
1 parent e78dcdb commit 4dea862

File tree

7 files changed

+57
-5
lines changed

7 files changed

+57
-5
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ CUDA_VISIBLE_DEVICES=0 swift infer --model_id_or_path qwen/Qwen-7B-Chat --datase
245245

246246
# Fine-tuned Model
247247
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
248+
249+
# Merge LoRA incremental weights and perform inference
250+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
251+
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
248252
```
249253

250254
**Web-UI**:
@@ -254,6 +258,10 @@ CUDA_VISIBLE_DEVICES=0 swift web-ui --model_id_or_path qwen/Qwen-7B-Chat
254258

255259
# Fine-tuned Model
256260
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
261+
262+
# Merge LoRA incremental weights and use web UI
263+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
264+
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
257265
```
258266

259267

README_CN.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ CUDA_VISIBLE_DEVICES=0 swift infer --model_id_or_path qwen/Qwen-7B-Chat --datase
242242

243243
# 微调后的模型
244244
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
245+
246+
# merge LoRA增量权重并推理
247+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
248+
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
245249
```
246250

247251
**Web-UI**:
@@ -251,6 +255,10 @@ CUDA_VISIBLE_DEVICES=0 swift web-ui --model_id_or_path qwen/Qwen-7B-Chat
251255

252256
# 微调后的模型
253257
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
258+
259+
# merge LoRA增量权重并使用web-ui
260+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
261+
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
254262
```
255263

256264

examples/pytorch/llm/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ CUDA_VISIBLE_DEVICES=0 swift infer --model_id_or_path qwen/Qwen-7B-Chat --datase
208208

209209
# Fine-tuned Model
210210
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
211+
212+
# Merge LoRA incremental weights and perform inference
213+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
214+
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
211215
```
212216

213217
**Web-UI**:
@@ -217,6 +221,10 @@ CUDA_VISIBLE_DEVICES=0 swift web-ui --model_id_or_path qwen/Qwen-7B-Chat
217221

218222
# Fine-tuned Model
219223
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
224+
225+
# Merge LoRA incremental weights and use web UI
226+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
227+
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
220228
```
221229

222230

examples/pytorch/llm/README_CN.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ CUDA_VISIBLE_DEVICES=0 swift infer --model_id_or_path qwen/Qwen-7B-Chat --datase
207207

208208
# 微调后的模型
209209
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
210+
211+
# merge LoRA增量权重并推理
212+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
213+
CUDA_VISIBLE_DEVICES=0 swift infer --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
210214
```
211215

212216
**Web-UI**:
@@ -216,6 +220,10 @@ CUDA_VISIBLE_DEVICES=0 swift web-ui --model_id_or_path qwen/Qwen-7B-Chat
216220

217221
# 微调后的模型
218222
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
223+
224+
# merge LoRA增量权重并使用web-ui
225+
swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
226+
CUDA_VISIBLE_DEVICES=0 swift web-ui --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
219227
```
220228

221229

swift/cli/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import sys
44
from typing import Dict, List, Optional
55

6-
from swift.cli import infer, sft, web_ui
6+
from swift.cli import infer, merge_lora, sft, web_ui
77

88
ROUTE_MAPPING: Dict[str, str] = {
99
'sft': sft.__file__,
1010
'infer': infer.__file__,
11-
'web-ui': web_ui.__file__
11+
'web-ui': web_ui.__file__,
12+
'merge-lora': merge_lora.__file__
1213
}
14+
1315
ROUTE_MAPPING.update(
1416
{k.replace('-', '_'): v
1517
for k, v in ROUTE_MAPPING.items()})

swift/cli/merge_lora.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from swift.llm import InferArguments, merge_lora
2+
from swift.utils import parse_args
3+
4+
if __name__ == '__main__':
5+
args, remaining_argv = parse_args(InferArguments)
6+
merge_lora(args, replace_if_exists=True)

swift/llm/infer.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020

2121
def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
22+
logger.info(f'replace_if_exists: {replace_if_exists}')
2223
assert args.ckpt_dir is not None
2324
assert args.sft_type == 'lora'
2425
assert 'int4' not in args.model_type, 'int4 model is not supported'
@@ -65,10 +66,21 @@ def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
6566
res.pop('adapter_cfg', None)
6667
with open(new_configuration_path, 'w') as f:
6768
json.dump(res, f, ensure_ascii=False, indent=4)
68-
logger.info('Successfully merged LoRA.')
69+
# sft_args
70+
sft_args_fname = 'sft_args.json'
71+
old_sft_args_path = os.path.join(old_ckpt_dir, sft_args_fname)
72+
new_sft_args_path = os.path.join(args.ckpt_dir, sft_args_fname)
73+
if os.path.exists(old_sft_args_path):
74+
with open(old_sft_args_path, 'r') as f:
75+
res = json.load(f)
76+
res['sft_type'] = 'full'
77+
with open(new_sft_args_path, 'w') as f:
78+
json.dump(res, f, ensure_ascii=False, indent=2)
79+
logger.info(f'Successfully merged LoRA and saved in {args.ckpt_dir}.')
6980
else:
70-
logger.info('The weight directory for the merged LoRA already exists, '
71-
'skipping the saving process.')
81+
logger.info(
82+
f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
83+
'skipping the saving process.')
7284

7385

7486
def prepare_model_template(

0 commit comments

Comments
 (0)