
Commit b5b897a

Feat/multi round chat (#35)
1 parent 3d6400b commit b5b897a

File tree

12 files changed: +282, -52 lines


README.md

Lines changed: 5 additions & 2 deletions
@@ -33,8 +33,11 @@ Key features:
 1. supported sft method: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full(full parameter fine tuning), ...
 2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
-3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, ...
-4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, coco-en
+3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, multimodal and agent sft, multi-round chat, ...
+4. supported datasets:
+    1. nlp: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh
+    2. agent: damo-agent-zh, damo-agent-mini-zh
+    3. multi-modal: coco-en
 5. supported templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default

 # Installation
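The multi-round chat capability added to the feature list above amounts, roughly, to packing previous turns into the prompt. Below is a minimal sketch of what one multi-round sample might look like when rendered with the chatml(qwen) template; the field names (system, history, query) and the exact token layout are assumptions for illustration, not the repository's actual preprocessing code.

# Hypothetical multi-round sample and a chatml-style rendering (illustration
# only; field names and token layout are assumptions, not this repo's code).
sample = {
    'system': 'you are a helpful assistant!',
    'history': [('What is the capital of France?', 'The capital of France is Paris.')],
    'query': 'What is the city famous for?',
}

def render_chatml(sample: dict) -> str:
    # system turn, then alternating user/assistant turns from history,
    # then the current query with an open assistant turn for generation
    text = f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
    for user, assistant in sample['history']:
        text += f"<|im_start|>user\n{user}<|im_end|>\n"
        text += f"<|im_start|>assistant\n{assistant}<|im_end|>\n"
    text += f"<|im_start|>user\n{sample['query']}<|im_end|>\n<|im_start|>assistant\n"
    return text

print(render_chatml(sample))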

README_CN.md

Lines changed: 5 additions & 2 deletions
@@ -31,8 +31,11 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is a scalable
 1. supported sft methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full parameter fine-tuning, ...
 2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
-3. supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, push to modelscope hub, custom datasets, ...
-4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, coco-en
+3. supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, push to modelscope hub, custom datasets, multimodal and agent sft, multi-round chat, ...
+4. supported datasets:
+    1. nlp: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh
+    2. agent: damo-agent-zh, damo-agent-mini-zh
+    3. multi-modal: coco-en
 5. supported chat templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default

 # Installation

examples/pytorch/llm/README.md

Lines changed: 5 additions & 2 deletions
@@ -17,8 +17,11 @@
 ## Features
 1. supported sft method: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full(full parameter fine tuning), ...
 2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
-3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, ...
-4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, coco-en
+3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, multimodal and agent sft, multi-round chat, ...
+4. supported datasets:
+    1. nlp: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh
+    2. agent: damo-agent-zh, damo-agent-mini-zh
+    3. multi-modal: coco-en
 5. supported templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default

 ## Prepare the Environment

examples/pytorch/llm/README_CN.md

Lines changed: 5 additions & 2 deletions
@@ -18,8 +18,11 @@
 ## Features
 1. supported sft methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full parameter fine-tuning, ...
 2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
-3. supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, push to modelscope hub, custom datasets, ...
-4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, coco-en
+3. supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, push to modelscope hub, custom datasets, multimodal and agent sft, multi-round chat, ...
+4. supported datasets:
+    1. nlp: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh
+    2. agent: damo-agent-zh, damo-agent-mini-zh
+    3. multi-modal: coco-en
 5. supported chat templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default

 ## Prepare the Environment
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# 10G
+CUDA_VISIBLE_DEVICES=0 \
+python src/llm_infer.py \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    --template_type chatml \
+    --dtype bf16 \
+    --ckpt_dir "runs/qwen-7b-chat/vx_xxx/checkpoint-xxx" \
+    --eval_human false \
+    --dataset damo-agent-mini-zh \
+    --dataset_sample -1 \
+    --quantization_bit 4 \
+    --bnb_4bit_comp_dtype bf16 \
+    --max_new_tokens 1024 \
+    --temperature 0.9 \
+    --top_k 50 \
+    --top_p 0.9 \
+    --do_sample true \
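The script above runs batch inference over damo-agent-mini-zh. When chatting interactively instead (--eval_human true), multi-round behaviour comes down to threading the accumulated history through each generation call. A minimal sketch of that loop follows, assuming a checkpoint whose remote code exposes a chat(tokenizer, query, history) method (as qwen-7b-chat does); this is an illustration, not the code inside src/llm_infer.py.

# Illustrative multi-round inference loop (assumptions: 'Qwen/Qwen-7B-Chat' as
# a stand-in checkpoint and a model.chat(tokenizer, query, history) method
# provided by its remote code; not the implementation of src/llm_infer.py).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'Qwen/Qwen-7B-Chat'
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, device_map='auto').eval()

history = None
for query in ['Hello, who are you?', 'What did I just ask you?']:
    # each call receives the prior turns, so the second query can refer back
    response, history = model.chat(tokenizer, query, history=history)
    print(f'user: {query}\nassistant: {response}')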
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# 4 * 16GB VRAM
+nproc_per_node=4
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+torchrun \
+    --nproc_per_node=$nproc_per_node \
+    --master_port 29500 \
+    src/llm_sft.py \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    --template_type chatml \
+    --dtype bf16 \
+    --output_dir runs \
+    --ddp_backend nccl \
+    --dataset damo-agent-mini-zh \
+    --dataset_sample -1 \
+    --num_train_epochs 1 \
+    --max_length 2048 \
+    --quantization_bit 4 \
+    --bnb_4bit_comp_dtype bf16 \
+    --lora_rank 64 \
+    --lora_alpha 16 \
+    --lora_dropout_p 0.05 \
+    --lora_target_modules ALL \
+    --batch_size 1 \
+    --weight_decay 0. \
+    --learning_rate 1e-4 \
+    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
+    --max_grad_norm 0.5 \
+    --warmup_ratio 0.03 \
+    --eval_steps 50 \
+    --save_steps 50 \
+    --save_total_limit 2 \
+    --logging_steps 10 \
+    --use_flash_attn false \
+    --push_to_hub false \
+    --hub_model_id qwen-7b-chat-qlora \
+    --hub_private_repo true \
+    --hub_token 'your-sdk-token' \
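A quick check on the accumulation arithmetic in the script above: with batch_size 1 per process, nproc_per_node=4 processes, and gradient_accumulation_steps = 16 / 4 = 4, the effective global batch size works out to 1 * 4 * 4 = 16; the expr presumably keeps that figure constant if the number of processes is changed.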

examples/pytorch/llm/src/llm_sft.py

Lines changed: 2 additions & 0 deletions
@@ -293,8 +293,10 @@ def llm_sft(args: SftArguments) -> None:
     model.config.use_cache = False
     model.enable_input_require_grads()
     if is_dist():
+        trainer_args._frozen = False  # Compatible with transformers==4.32.0
         trainer_args.ddp_find_unused_parameters = False
         trainer_args.ddp_broadcast_buffers = False
+        trainer_args._frozen = True
     logger.info(f'trainer_args: {trainer_args}')

     trainer = Seq2SeqTrainer(
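The two added lines deal with the immutable training arguments in transformers 4.32.0: once the dataclass is frozen after construction, plain attribute assignment raises, so the code briefly clears the private _frozen flag, sets the DDP options, and restores it. A minimal standalone sketch of the same pattern, under the assumption that _frozen gates attribute assignment on that version:

# Sketch of the unfreeze/refreeze pattern used in the diff above (assumption:
# on transformers==4.32.0 a frozen TrainingArguments rejects assignment and
# the private `_frozen` flag gates it).
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(output_dir='runs')
try:
    args.ddp_find_unused_parameters = False  # may raise once the instance is frozen
except Exception:
    args._frozen = False                     # temporarily allow mutation
    args.ddp_find_unused_parameters = False
    args.ddp_broadcast_buffers = False
    args._frozen = True                      # restore immutability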
Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
 from .dataset import DATASET_MAPPING, get_dataset, process_dataset
 from .model import MODEL_MAPPING, get_model_tokenizer
 from .preprocess import TEMPLATE_MAPPING, get_preprocess
-from .utils import (broadcast_string, find_all_linear_for_lora,
-                    get_dist_setting, inference, is_dist, is_master,
-                    plot_images, select_bnb, select_dtype, show_layers)
+from .utils import (broadcast_string, download_dataset,
+                    find_all_linear_for_lora, get_dist_setting, inference,
+                    is_dist, is_master, plot_images, select_bnb, select_dtype,
+                    show_layers)
