Commit 2868bae

update examples, add dpo & lora training (#2563)
Co-authored-by: llbdyiu66 <[email protected]>
1 parent de621f9 commit 2868bae

6 files changed: +596 -426 lines changed


examples/README.md

Lines changed: 59 additions & 5 deletions
@@ -1,6 +1,6 @@
-## Fine-tuning
+## 1. Fine-tuning

-### Data Preparation
+### 1.1 Data Preparation

The supported fine-tuning data format is a JSON file with one dictionary per line; each dictionary contains the following fields:
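As an illustration only of the one-dictionary-per-line layout described above (the field descriptions themselves fall outside this diff hunk), a minimal loading sketch; the file path is hypothetical:

```python
import json

# Read a dataset in the format described above: one JSON dictionary per line.
# "alpaca_demo/train.json" is a hypothetical path, not taken from this commit.
with open("alpaca_demo/train.json", "r", encoding="utf-8") as f:
    examples = [json.loads(line) for line in f if line.strip()]

print(f"{len(examples)} examples, fields of the first one: {sorted(examples[0])}")
```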

@@ -21,7 +21,7 @@ wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz
tar -xvf alpaca_demo.gz
```

-### Full-parameter fine-tuning: SFT
+### 1.2 Full-parameter SFT

Single GPU
```bash
@@ -34,10 +34,64 @@ python -u run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
```

-### LoRA
+### 1.3 LoRA SFT

-LoRA launch command reference
+LoRA SFT launch command reference
```bash
# requires roughly 9 GB of GPU memory
python -u run_finetune.py ./config/qwen/lora_argument_qwen2_0p5b.json
```

## 2. Alignment

### 2.1 Data Preparation

The supported alignment (preference) data format is a JSON file with one dictionary per line; each dictionary contains the following fields:

- `src` : `str, List(str)`, user dialogue content.
- `tgt` : `str, List(str)`, system response content.
- `response` : `str, List(str)`, contains the chosen and rejected responses.
- `sort` : `List(int)`, the sort values distinguish the chosen and rejected entries in `response` (the response with the smaller sort value is rejected, the one with the larger value is chosen).

Sample data:

```text
{
    "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
    "tgt": [],
    "response": [
        "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
        "As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
    ],

    "sort": [1, 0]
}
...
```

To make testing easier, we also provide a preference dataset that can be used directly:

```bash
wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
tar -zxvf ultrafeedback_binarized.tar.gz
```
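The `sort` convention above determines which response is chosen and which is rejected. A minimal sketch, not part of the repository, of splitting one such line into a (chosen, rejected) pair; the extracted file name is a guess:

```python
import json

def split_preference(example: dict) -> tuple[str, str]:
    """Return (chosen, rejected): the larger sort value marks chosen, the smaller rejected."""
    responses, sort = example["response"], example["sort"]
    chosen = responses[sort.index(max(sort))]
    rejected = responses[sort.index(min(sort))]
    return chosen, rejected

# "ultrafeedback_binarized/train.jsonl" is a hypothetical path inside the downloaded archive.
with open("ultrafeedback_binarized/train.jsonl", "r", encoding="utf-8") as f:
    first = json.loads(next(f))

chosen, rejected = split_preference(first)
print("chosen:", chosen[:80], "\nrejected:", rejected[:80])
```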

### 2.2 Full-parameter DPO

Single GPU
```bash
python -u ./alignment/dpo/run_dpo.py ./config/qwen/dpo_argument_qwen2_0p5b.json
```

Multi-GPU
```bash
python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/qwen/dpo_argument_qwen2_0p5b.json
```

### 2.3 LoRA DPO

LoRA DPO launch command reference
```bash
python -u ./alignment/dpo/run_dpo.py ./config/qwen/dpo_lora_argument_qwen2_0p5b.json
```
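The JSON files passed to `run_dpo.py` populate the DPO argument dataclasses introduced by this commit (see the new file below). As an illustration only, a few of the knobs such a `dpo_lora_argument_qwen2_0p5b.json` could set; the model identifier is assumed, and a real config also carries dataset paths and standard `TrainingArguments` fields omitted here:

```python
import json

# Illustrative subset of a LoRA-DPO config; keys correspond to fields of the
# DPO dataclasses added in this commit. "Qwen/Qwen2-0.5B" is an assumed identifier.
dpo_lora_config = {
    "model_name_or_path": "Qwen/Qwen2-0.5B",  # DPOModelArgument.model_name_or_path (assumed value)
    "lora": True,                             # DPOConfig.lora: run DPO on a LoRA adapter
    "lora_rank": 8,                           # DPOModelArgument.lora_rank
    "beta": 0.1,                              # DPOConfig.beta for the DPO loss
    "loss_type": "sigmoid",                   # DPOConfig.loss_type
    "max_seq_len": 4096,                      # DPODataArgument.max_seq_len
    "max_prompt_len": 2048,                   # DPODataArgument.max_prompt_len
}
print(json.dumps(dpo_lora_config, indent=2))
```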
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Optional

from paddleformers.trainer import TrainingArguments
from paddleformers.trainer.trainer_utils import IntervalStrategy
from paddleformers.trainer.utils.doc import add_start_docstrings
from paddleformers.transformers.configuration_utils import llmmetaclass
from paddleformers.trl import DataConfig


@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class DPOTrainingArguments(TrainingArguments):
    """DPOTrainingArguments"""

    unified_checkpoint: bool = field(
        default=True,
        metadata={"help": "Whether to use the unified checkpoint format for saving and loading."},
    )
    unified_checkpoint_config: Optional[str] = field(
        default="",
        metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
    )
    autotuner_benchmark: bool = field(
        default=False,
        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
    )
    benchmark: bool = field(
        default=False,
        metadata={"help": "Whether to run in benchmark mode."},
    )
    use_intermediate_api: bool = field(
        default=False,
        metadata={"help": "Flag indicating whether to use the intermediate API for model."},
    )
    num_hidden_layers: int = field(default=2, metadata={"help": "The number of hidden layers in the network model."})

    def __post_init__(self):
        super().__post_init__()
        if self.autotuner_benchmark:
            # Autotuner benchmark mode: train a handful of steps only, with
            # evaluation, prediction, export, reporting and checkpoint saving disabled.
            self.num_train_epochs = 1
            self.max_steps = 5
            self.do_train = True
            self.do_export = False
            self.do_predict = False
            self.do_eval = False
            self.overwrite_output_dir = True
            self.load_best_model_at_end = False
            self.report_to = []
            self.save_strategy = IntervalStrategy.NO
            self.evaluation_strategy = IntervalStrategy.NO
            if not self.disable_tqdm:
                self.logging_steps = 1
                self.logging_strategy = IntervalStrategy.STEPS
        if self.benchmark:
            # Benchmark mode: train only, with evaluation, prediction, export
            # and checkpoint saving disabled.
            self.do_train = True
            self.do_export = False
            self.do_predict = False
            self.do_eval = False
            self.overwrite_output_dir = True
            self.load_best_model_at_end = False
            self.save_strategy = IntervalStrategy.NO
            self.evaluation_strategy = IntervalStrategy.NO
            if not self.disable_tqdm:
                self.logging_steps = 1
                self.logging_strategy = IntervalStrategy.STEPS
            if self.max_steps > 0:
                self.num_train_epochs = 1


@dataclass
class DPOConfig:
    """DPOConfig"""

    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"})
    label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
    loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
    pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"})
    sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"})
    dpop_lambda: float = field(default=50, metadata={"help": "lambda coefficient for the DPOP penalty term"})
    ref_model_update_steps: int = field(default=-1, metadata={"help": "Update the reference model state dict every N steps; -1 disables updates."})
    reference_free: bool = field(default=False, metadata={"help": "No reference model."})
    lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
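
# Reference note (not from this commit): with loss_type="sigmoid" and label smoothing
# eps, the standard DPO objective these hyperparameters refer to (per the DPO paper) is
#     L = -[(1 - eps) * log sigmoid(beta * h) + eps * log sigmoid(-beta * h)],
# where h = log[pi_theta(y_w|x) / pi_ref(y_w|x)] - log[pi_theta(y_l|x) / pi_ref(y_l|x)]
# for chosen y_w and rejected y_l. pref_loss_ratio and sft_loss_ratio presumably weight
# the preference loss against an auxiliary SFT cross-entropy term on the chosen response.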


@dataclass
class DPODataArgument(DataConfig):
    """DataArgument"""

    max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
    max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
    num_samples_each_epoch: int = field(default=6000000, metadata={"help": "Number of samples per training epoch."})
    buffer_size: int = field(default=1000, metadata={"help": "Preloading buffer capacity."})
    mask_out_eos_token: bool = field(default=True, metadata={"help": "EOS token loss masking."})


@dataclass
class DPOModelArgument:
    """ModelArgument"""

    model_name_or_path: str = field(
        default=None, metadata={"help": "Pretrained model name or path to local directory."}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    download_hub: str = field(
        default="aistudio",
        metadata={
            "help": "The source for model downloading, options include `huggingface`, `aistudio`, `modelscope`, default `aistudio`"
        },
    )
    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
    weight_quantize_algo: str = field(
        default=None,
        metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
    )
    fuse_attention_qkv: bool = field(
        default=None,
        metadata={"help": "Whether to fuse the attention QKV projections."},
    )
    fuse_attention_ffn: bool = field(
        default=None,
        metadata={"help": "Whether to fuse the up and gate projections in the MLP block."},
    )
    use_sparse_head_and_loss_fn: bool = field(
        default=True,
        metadata={"help": "Whether to use sparse indexing for loss calculation."},
    )
    use_fused_head_and_loss_fn: bool = field(
        default=True,
        metadata={"help": "Whether to use fused kernel to calculate lm head and loss."},
    )
    use_attn_mask_startend_row_indices: bool = field(
        default=True,
        metadata={"help": "Sparse attention mode."},
    )

    # LoRA
    lora_rank: int = field(default=8, metadata={"help": "LoRA rank."})
    lora_path: str = field(default=None, metadata={"help": "Path to a LoRA state dict used for initialization."})
    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA."})
    lora_plus_scale: float = field(default=1.0, metadata={"help": "LoRA B scale in the LoRA+ technique."})
    lora_alpha: int = field(default=-1, metadata={"help": "LoRA alpha scaling parameter."})
    rslora_plus: bool = field(default=False, metadata={"help": "Whether to enable RsLoRA+ to strengthen LoRA performance."})
    use_quick_lora: bool = field(default=True, metadata={"help": "Whether to use Quick LoRA."})
0 commit comments

Comments
 (0)