Commit 3baf4e5

update feat: merge lora (#82)
1 parent: c4c68ad

File tree

examples/pytorch/llm/README.md
examples/pytorch/llm/README_CN.md
examples/pytorch/llm/scripts/qwen_7b_chat/lora/infer.sh
examples/pytorch/llm/scripts/qwen_7b_chat/lora/merge_lora_and_infer.sh
examples/pytorch/llm/scripts/qwen_7b_chat/lora/sft.sh
examples/pytorch/llm/scripts/qwen_7b_chat/qlora/merge_lora_and_infer.sh
examples/pytorch/llm/src/merge_lora_and_infer.py
swift/utils/llm_utils.py

8 files changed: +98 -10 lines changed

examples/pytorch/llm/README.md

Lines changed: 2 additions & 2 deletions

@@ -74,7 +74,7 @@ cd swift/examples/pytorch/llm
 # If you want to push weights into the modelscope hub during training, you need to set '--push_to_hub true'.
 # Recommended experimental environment: A100
 bash scripts/qwen_7b_chat/lora/sft.sh
-bash scripts/qwen_7b_chat/lora/infer.sh
+bash scripts/qwen_7b_chat/lora/merge_lora_and_infer.sh
 
 # sft(lora+ddp) and infer qwen-7b-chat, requires 2*38GB GPU memory.
 # Recommended experimental environment: A100
@@ -90,7 +90,7 @@ bash scripts/qwen_7b_chat/lora_mp_ddp/infer.sh
 # If you want to use quantization, you need to `pip install bitsandbytes -U`
 # Recommended experimental environment: A10, 3090
 bash scripts/qwen_7b_chat/qlora/sft.sh
-bash scripts/qwen_7b_chat/qlora/infer.sh
+bash scripts/qwen_7b_chat/qlora/merge_lora_and_infer.sh
 
 # sft(qlora+ddp) and infer qwen-7b-chat, requires 2*14GB GPU memory.
 # Recommended experimental environment: A10, 3090

examples/pytorch/llm/README_CN.md

Lines changed: 2 additions & 2 deletions

@@ -76,7 +76,7 @@ cd swift/examples/pytorch/llm
 # If you want to push weights to the modelscope hub during training, you need to set `--push_to_hub true`.
 # Recommended experimental environment: A100
 bash scripts/qwen_7b_chat/lora/sft.sh
-bash scripts/qwen_7b_chat/lora/infer.sh
+bash scripts/qwen_7b_chat/lora/merge_lora_and_infer.sh
 
 # Fine-tune (lora+ddp) and infer qwen-7b-chat; requires 2 GPUs * 38GB of GPU memory.
 # Recommended experimental environment: A100
@@ -92,7 +92,7 @@ bash scripts/qwen_7b_chat/lora_mp_ddp/infer.sh
 # If you want to use quantization, you need to `pip install bitsandbytes -U`
 # Recommended experimental environment: 3090, A10
 bash scripts/qwen_7b_chat/qlora/sft.sh
-bash scripts/qwen_7b_chat/qlora/infer.sh
+bash scripts/qwen_7b_chat/qlora/merge_lora_and_infer.sh
 
 # Fine-tune (qlora+ddp) and infer qwen-7b-chat; requires 2 GPUs * 14GB of GPU memory.
 # Recommended experimental environment: 3090, A10

examples/pytorch/llm/scripts/qwen_7b_chat/lora/infer.sh

Lines changed: 3 additions & 3 deletions

@@ -6,10 +6,10 @@ python src/llm_infer.py \
     --dtype bf16 \
     --ckpt_dir "output/qwen-7b-chat/vx_xxx/checkpoint-xxx" \
     --eval_human false \
-    --dataset cot-en,cot-zh \
-    --max_length 2048 \
+    --dataset damo-agent-mini-zh \
+    --max_length 4096 \
     --use_flash_attn true \
-    --max_new_tokens 1024 \
+    --max_new_tokens 2048 \
     --temperature 0.9 \
     --top_k 20 \
     --top_p 0.9 \
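
Note: the sampling flags in this script map onto the standard transformers generation parameters. A minimal sketch of the rough correspondence (swift wires these up through its own `InferArguments`, so treat this as illustrative only):

```python
# Rough, illustrative mapping of the CLI flags above onto transformers'
# GenerationConfig; swift's own argument handling may differ in detail.
from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=2048,  # --max_new_tokens 2048
    temperature=0.9,      # --temperature 0.9
    top_k=20,             # --top_k 20
    top_p=0.9,            # --top_p 0.9
    do_sample=True,       # --do_sample true
)
```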

examples/pytorch/llm/scripts/qwen_7b_chat/lora/merge_lora_and_infer.sh

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+CUDA_VISIBLE_DEVICES=0 \
+python src/merge_lora_and_infer.py \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    --template_type chatml \
+    --dtype bf16 \
+    --ckpt_dir "output/qwen-7b-chat/vx_xxx/checkpoint-xxx" \
+    --eval_human false \
+    --dataset damo-agent-mini-zh \
+    --max_length 4096 \
+    --use_flash_attn true \
+    --max_new_tokens 2048 \
+    --temperature 0.9 \
+    --top_k 20 \
+    --top_p 0.9 \
+    --do_sample true \

examples/pytorch/llm/scripts/qwen_7b_chat/lora/sft.sh

Lines changed: 3 additions & 3 deletions

@@ -7,10 +7,10 @@ python src/llm_sft.py \
     --template_type chatml \
     --dtype bf16 \
     --output_dir output \
-    --dataset cot-en,cot-zh \
-    --train_dataset_sample 50000 \
+    --dataset damo-agent-mini-zh \
+    --train_dataset_sample -1 \
     --num_train_epochs 1 \
-    --max_length 2048 \
+    --max_length 4096 \
     --lora_rank 8 \
     --lora_alpha 32 \
     --lora_dropout_p 0. \
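
Note: the LoRA hyperparameters in this script have direct counterparts in the upstream peft `LoraConfig`; a minimal sketch under that assumption (swift.tuners ships its own LoRA implementation, so this is for orientation only, not the repo's API):

```python
# Illustrative peft equivalent of the LoRA flags above; swift's own LoRA
# config class differs, so treat this as a sketch rather than the real wiring.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,              # --lora_rank 8
    lora_alpha=32,    # --lora_alpha 32
    lora_dropout=0.,  # --lora_dropout_p 0.
)
```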

examples/pytorch/llm/scripts/qwen_7b_chat/qlora/merge_lora_and_infer.sh

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+CUDA_VISIBLE_DEVICES=0 \
+python src/merge_lora_and_infer.py \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    --template_type chatml \
+    --dtype bf16 \
+    --ckpt_dir "output/qwen-7b-chat/vx_xxx/checkpoint-xxx" \
+    --eval_human false \
+    --dataset advertise-gen \
+    --max_length 2048 \
+    --quantization_bit 4 \
+    --bnb_4bit_comp_dtype bf16 \
+    --use_flash_attn false \
+    --max_new_tokens 1024 \
+    --temperature 0.9 \
+    --top_k 20 \
+    --top_p 0.9 \
+    --do_sample true \
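
Note: in the standard transformers API, the two quantization flags above correspond roughly to a 4-bit bitsandbytes configuration; a sketch under that assumption:

```python
# Rough mapping of --quantization_bit 4 / --bnb_4bit_comp_dtype bf16 onto
# transformers' BitsAndBytesConfig; swift's argument plumbing may differ.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # --quantization_bit 4
    bnb_4bit_compute_dtype=torch.bfloat16,  # --bnb_4bit_comp_dtype bf16
)
```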

examples/pytorch/llm/src/merge_lora_and_infer.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import torch
+from transformers import BitsAndBytesConfig, GenerationConfig, TextStreamer
+from utils import (InferArguments, get_dataset, get_model_tokenizer,
+                   get_preprocess)
+from llm_infer import llm_infer
+from swift import Swift, get_logger
+from swift.tuners import LoRA
+from swift.utils import inference, parse_args, seed_everything
+
+logger = get_logger()
+
+
+def merge_lora(args: InferArguments) -> None:
+    assert args.sft_type == 'lora'
+    args.init_argument()
+    logger.info(f'device_count: {torch.cuda.device_count()}')
+
+    # ### Loading Model and Tokenizer
+    model, tokenizer = get_model_tokenizer(
+        args.model_type, torch_dtype=args.torch_dtype, device_map='cpu')
+
+    # ### Preparing LoRA
+    model = Swift.from_pretrained(model, args.ckpt_dir, inference_mode=True)
+    if not hasattr(model, 'peft_type'):
+        LoRA.unpatch_lora(model, model.adapters['default'].config, 'default')
+    else:
+        model.merge_and_unload()
+
+    new_ckpt_dir = os.path.abspath(
+        os.path.join(args.ckpt_dir, '..', 'output_ckpt'))
+    logger.info(f'new_ckpt_dir: `{new_ckpt_dir}`')
+    logger.info("Setting args.sft_type: 'full'")
+    logger.info(f'Setting args.ckpt_dir: {new_ckpt_dir}')
+    args.ckpt_dir = new_ckpt_dir
+    args.sft_type = 'full'
+    if not os.path.exists(args.ckpt_dir):
+        model.model.save_pretrained(args.ckpt_dir)
+        tokenizer.save_pretrained(args.ckpt_dir)
+
+
+if __name__ == '__main__':
+    args, remaining_argv = parse_args(InferArguments)
+    if len(remaining_argv) > 0:
+        if args.ignore_args_error:
+            logger.warning(f'remaining_argv: {remaining_argv}')
+        else:
+            raise ValueError(f'remaining_argv: {remaining_argv}')
+    merge_lora(args)
+    llm_infer(args)
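
Note: for the peft branch above (`hasattr(model, 'peft_type')`), the merge step is the standard LoRA weight fold-in. A minimal standalone sketch with the upstream peft API, assuming the checkpoint is a plain peft LoRA adapter; the base-model ID is an assumption and the checkpoint path keeps the same placeholder as the scripts:

```python
# Standalone sketch of the merge step using upstream peft directly; assumes
# a plain peft LoRA adapter checkpoint. 'vx_xxx/checkpoint-xxx' is the same
# placeholder as in the scripts; substitute a real checkpoint path.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    'qwen/Qwen-7B-Chat', trust_remote_code=True)  # assumed base model ID
model = PeftModel.from_pretrained(
    base, 'output/qwen-7b-chat/vx_xxx/checkpoint-xxx')
merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained('output/qwen-7b-chat/vx_xxx/output_ckpt')
```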

swift/utils/llm_utils.py

Lines changed: 1 addition & 0 deletions

@@ -115,6 +115,7 @@ def inference(input_ids: List[int],
               tokenizer,
               streamer: Optional[TextStreamer] = None) -> str:
     generation_config = getattr(model, 'generation_config', None)
+    streamer.skip_prompt = True
     print(f'[INFERENCE]{tokenizer.decode(input_ids)}', end='')
     input_ids = torch.tensor(input_ids)[None].cuda()
     attention_mask = torch.ones_like(input_ids)
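
Note: the added `streamer.skip_prompt = True` stops the streamer from echoing the prompt, which the `print` on the next line already emits. The standard transformers `TextStreamer` exposes the same switch at construction time; a minimal standalone sketch (the model name is illustrative, not what the repo uses):

```python
# Standalone illustration of TextStreamer's skip_prompt behavior;
# 'gpt2' is just a small illustrative model.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
streamer = TextStreamer(tokenizer, skip_prompt=True)  # don't re-print the prompt
inputs = tokenizer('Hello, world', return_tensors='pt')
model.generate(**inputs, streamer=streamer, max_new_tokens=16)  # streams only new tokens
```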
