Commit 3aa9f4c

[llm] Add KTO (PaddlePaddle#9689)

* add kto
* add kto
* add kto
* add kto'
* add kto
* add
* fix conflict
* add llm

1 parent dff62a1 commit 3aa9f4c

30 files changed: +1352 additions, −385 deletions

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -129,7 +129,7 @@
 * Large-model pretraining, fine-tuning (including SFT and PEFT), alignment, and quantization now support the LLaMA, Baichuan, Bloom, ChatGLM, Mistral, OPT, and Qwen series. The LLM pretraining, fine-tuning, alignment, and quantization support matrix is as follows:
 
-| Model | Pretrain | SFT | LoRA | FlashMask | Prefix Tuning | DPO/SimPO/ORPO | RLHF | Mergekit | Quantization |
+| Model | Pretrain | SFT | LoRA | FlashMask | Prefix Tuning | DPO/SimPO/ORPO/KTO | RLHF | Mergekit | Quantization |
 |--------------------------------------------|:--------:|:---:|:----:|:---------:|:-------------:|:--------------:|:----:|:-----:|:------------:|
 | [Llama](./llm/config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [Qwen](./llm/config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 |
```

llm/README.md

Lines changed: 55 additions & 4 deletions
```diff
@@ -16,7 +16,7 @@
 
 ## 🛠️ Supported Model List 🛠️
 
-| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO/SimPO/ORPO | RLHF | Mergekit | Quantization | Torch convert |
+| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO/SimPO/ORPO/KTO | RLHF | Mergekit | Quantization | Torch convert |
 |----------------------------------------|----------|-----|------|---------------|----------------|------|-------|--------------|---------------|
 | [LLaMA](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [Qwen](./config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | ✅ |
```
```diff
@@ -154,7 +154,7 @@ python run_finetune.py ./config/llama/pt_argument.json
 
 ### 3. Alignment
 
-We support preference-alignment strategies such as DPO and RLHF. The DPO strategy uses zero_padding, combined with the FlashMask strategy, to effectively improve training efficiency.
+We support preference-alignment strategies such as DPO, KTO, and RLHF. The DPO and KTO strategies use zero_padding, combined with the FlashMask strategy, to effectively improve training efficiency.
 
 #### 3.1 DPO
```
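The zero_padding strategy mentioned in the hunk above packs examples back-to-back into one sequence instead of padding each example to the maximum length. A minimal sketch of the packing idea — this is our own illustration, not the repo's implementation:

```python
def zero_padding_pack(sequences, max_len):
    """Greedily pack token sequences into bins of at most max_len tokens,
    so compute is spent on real tokens instead of padding."""
    bins, current, used = [], [], 0
    for seq in sequences:
        if used + len(seq) > max_len and current:
            # flush the current bin before it overflows
            bins.append([tok for s in current for tok in s])
            current, used = [], 0
        current.append(seq)
        used += len(seq)
    if current:
        bins.append([tok for s in current for tok in s])
    return bins

packed = zero_padding_pack([[1] * 300, [2] * 500, [3] * 400, [4] * 100], max_len=1024)
print(len(packed))  # → 2  (two packed sequences instead of four padded ones)
```

With FlashMask, an attention mask is then built so tokens from different packed examples do not attend to each other.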

````diff
@@ -183,7 +183,7 @@ python run_finetune.py ./config/llama/pt_argument.json
 ...
 ```
 
-For easy testing, we also provide an advertisement-generation dataset that can be used directly
+For easy testing, we also provide a preference dataset that can be used directly
 
 ```bash
 wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
````
````diff
@@ -196,9 +196,60 @@ tar -zxvf ultrafeedback_binarized.tar.gz
 # DPO launch command (reference)
 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_argument.json
 ```
+
+##### LoRA DPO
+
+```bash
+# DPO launch command (reference)
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_lora_argument.json
+```
 For more DPO technical details and usage instructions, see the [DPO documentation](./docs/dpo.md)
 
-#### 3.2 RLHF
+#### 3.2 KTO
+
+##### Data preparation
+
+The fine-tuning data format we support is a json file with one dictionary per line; each dictionary contains the following fields:
+
+- `src` : `str, List(str)`, the user's dialogue content.
+- `tgt` : `str, List(str)`, the system's reply content.
+- `response` : `str, List(str)`, the candidate response.
+- `sort` : `List(int)`, the sort value marks whether the response is chosen or rejected (0 means rejected, 1 means chosen).
+
+Sample data:
+
+```text
+{
+    "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
+    "tgt": [],
+    "response": [
+        "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?"],
+    "sort": [1]
+}
+...
+```
+
+For easy testing, we also provide a preference dataset that can be used directly:
+
+```bash
+wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized_pointwise.tar.gz
+tar -zxvf ultrafeedback_binarized_pointwise.tar.gz
+```
+
+##### Full-parameter KTO
+
+```bash
+# KTO launch command (reference)
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/kto/run_kto.py ./config/llama/kto_argument.json
+```
+##### LoRA KTO
+
+```bash
+# KTO launch command (reference)
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/kto/run_kto.py ./config/llama/kto_lora_argument.json
+```
+
+#### 3.3 RLHF
 
 The PaddlePaddle LLM toolkit provides code and complete usage examples for human-preference alignment of LLMs based on the reinforcement-learning PPO algorithm, supporting **3D distributed parallel training and generation acceleration via inference optimization during the rollout stage**. For a detailed tutorial, see the [RLHF documentation](./docs/rlhf.md)
````
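The pointwise KTO format documented above can be sanity-checked with a few lines of Python. This is a minimal sketch: the field names follow the docs, but the validator itself is our own illustration, not part of the repo:

```python
import json

def validate_kto_example(line: str) -> dict:
    """Parse one jsonl line and check the pointwise KTO fields."""
    ex = json.loads(line)
    assert isinstance(ex["src"], list) and ex["src"], "src: non-empty list of user turns"
    assert isinstance(ex["tgt"], list), "tgt: list of system turns (may be empty)"
    assert isinstance(ex["response"], list) and ex["response"], "response: candidate reply"
    assert all(s in (0, 1) for s in ex["sort"]), "sort: 0 = rejected, 1 = chosen"
    return ex

# one example line, shaped like the sample data above
sample = json.dumps({
    "src": ["What is the capital of France?"],
    "tgt": [],
    "response": ["Paris is the capital of France."],
    "sort": [1],
})
ex = validate_kto_example(sample)
print("chosen" if ex["sort"][0] == 1 else "rejected")  # → chosen
```

Unlike the pairwise DPO format, each line carries a single response with a chosen/rejected label rather than a preference pair.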

llm/alignment/dpo/run_dpo.py

Lines changed: 23 additions & 22 deletions
```diff
@@ -45,6 +45,8 @@
     Qwen2ForCausalLMPipe,
     register_sequence_parallel_allreduce_hooks,
 )
+from paddlenlp.transformers.configuration_utils import LlmMetaConfig
+from paddlenlp.transformers.refined_recompute import update_refined_recompute
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
@@ -80,14 +82,14 @@ def main():
             hasattr(training_args, "pipeline_parallel_config")
             and "enable_clear_every_step_cache" in training_args.pipeline_parallel_config
         ), "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp."
-    if model_args.sequence_parallel:
+    if training_args.sequence_parallel:
         if training_args.pipeline_parallel_degree > 1:
             assert (
                 hasattr(training_args, "pipeline_parallel_config")
                 and "disable_partial_send_recv" in training_args.pipeline_parallel_config
             ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp."
         if training_args.tensor_parallel_degree <= 1:
-            model_args.sequence_parallel = False
+            training_args.sequence_parallel = False
             logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.")
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")
@@ -117,39 +119,38 @@ def main():
         dtype = "bfloat16"
 
     logger.info("Start to load model & tokenizer.")
-    model_kwargs = dict(
-        pretrained_model_name_or_path=model_args.model_name_or_path,
-        dtype=dtype,
-        tensor_parallel_degree=training_args.tensor_parallel_degree,
-        tensor_parallel_rank=training_args.tensor_parallel_rank,
-        recompute_granularity=training_args.recompute_granularity,
-        use_flash_attention=training_args.use_flash_attention,
-        tensor_parallel_output=training_args.tensor_parallel_output,
-        use_fused_rms_norm=training_args.use_fused_rms_norm,
-        use_fused_rope=training_args.use_fused_rope,
-        use_fused_linear=training_args.use_fused_linear,
-        use_fused_dropout_add=training_args.use_fused_dropout_add,
-    )
+
+    model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype)
+    LlmMetaConfig.set_llm_config(model_config, training_args)
+    model_config.refined_recompute = update_refined_recompute(
+        training_args.refined_recompute,
+        dpo_config.lora,
+    )
+    if not dpo_config.reference_free and not dpo_config.lora:
+        ref_model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype)
+        LlmMetaConfig.set_llm_config(ref_model_config, training_args)
+        ref_model_config.refined_recompute = update_refined_recompute(
+            training_args.refined_recompute,
+            dpo_config.lora,
+        )
 
     if training_args.pipeline_parallel_degree > 1:
         model_class = AutoModelForCausalLMPipe
-        model_kwargs["dpo_config"] = dpo_config
+        model_config.dpo_config = dpo_config
     else:
         model_class = AutoModelForCausalLM
     if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
-        model = model_class.from_pretrained(**model_kwargs)
+        model = model_class.from_pretrained(model_args.model_name_or_path, config=model_config)
         # for DPO save
         if not dpo_config.reference_free and not dpo_config.lora:
-            config = AutoConfig.from_pretrained(**model_kwargs)
-            ref_model = model_class.from_config(config, dtype=dtype)
+            ref_model = model_class.from_config(ref_model_config)
             ref_model.set_state_dict(model.state_dict())
         else:
             ref_model = None
     else:
-        config = AutoConfig.from_pretrained(**model_kwargs)
-        model = model_class.from_config(config, dtype=dtype)
+        model = model_class.from_config(model_config)
         if not dpo_config.reference_free and not dpo_config.lora:
-            ref_model = model_class.from_config(config, dtype=dtype)
+            ref_model = model_class.from_config(ref_model_config)
         else:
             ref_model = None
     if training_args.pipeline_parallel_degree > 1:
@@ -163,7 +164,7 @@ def main():
 
     if training_args.sequence_parallel:
         register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
+            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
```
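The refactor above replaces a flat `model_kwargs` dict with a config object (`AutoConfig` plus `LlmMetaConfig.set_llm_config`) that is filled once and then handed to `from_pretrained` / `from_config`. The pattern can be sketched generically; the `ModelConfig` class and `set_llm_config` helper below are illustrative stand-ins, not PaddleNLP APIs:

```python
from dataclasses import dataclass, field

@dataclass
class ModelConfig:
    """Stand-in for a model config: one object that carries all model settings."""
    dtype: str = "bfloat16"
    use_flash_attention: bool = False
    extras: dict = field(default_factory=dict)

def set_llm_config(config: ModelConfig, training_args: dict) -> None:
    """Stand-in for LlmMetaConfig.set_llm_config: copy trainer flags onto the config
    once, so model constructors no longer take a long list of keyword arguments."""
    for key, value in training_args.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            config.extras[key] = value

config = ModelConfig()
set_llm_config(config, {"use_flash_attention": True, "recompute_granularity": "full"})
print(config.use_flash_attention)  # → True
```

The benefit visible in the diff: the policy and reference models each get their own fully-populated config, instead of sharing one kwargs dict that is mutated along the way.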

llm/alignment/kto/kto_argument.py

Lines changed: 135 additions & 0 deletions (new file)
@@ -0,0 +1,135 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from dataclasses import dataclass, field
17+
from typing import Optional
18+
19+
from paddlenlp.trainer import TrainingArguments
20+
from paddlenlp.trainer.trainer_utils import IntervalStrategy
21+
from paddlenlp.trainer.utils.doc import add_start_docstrings
22+
from paddlenlp.transformers.configuration_utils import llmmetaclass
23+
24+
25+
@dataclass
26+
@llmmetaclass
27+
@add_start_docstrings(TrainingArguments.__doc__)
28+
class KTOTrainingArguments(TrainingArguments):
29+
"""KTOTrainingArguments"""
30+
31+
unified_checkpoint: bool = field(
32+
default=True,
33+
metadata={"help": "Enable fused linear grad add strategy."},
34+
)
35+
unified_checkpoint_config: Optional[str] = field(
36+
default="",
37+
metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
38+
)
39+
autotuner_benchmark: bool = field(
40+
default=False,
41+
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
42+
)
43+
benchmark: bool = field(
44+
default=False,
45+
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
46+
)
47+
48+
def __post_init__(self):
49+
super().__post_init__()
50+
if self.autotuner_benchmark:
51+
self.num_train_epochs = 1
52+
self.max_steps = 5
53+
self.do_train = True
54+
self.do_export = False
55+
self.do_predict = False
56+
self.do_eval = False
57+
self.overwrite_output_dir = True
58+
self.load_best_model_at_end = False
59+
self.report_to = []
60+
self.save_strategy = IntervalStrategy.NO
61+
self.evaluation_strategy = IntervalStrategy.NO
62+
if not self.disable_tqdm:
63+
self.logging_steps = 1
64+
self.logging_strategy = IntervalStrategy.STEPS
65+
if self.benchmark:
66+
self.do_train = True
67+
self.do_export = False
68+
self.do_predict = False
69+
self.do_eval = False
70+
self.overwrite_output_dir = True
71+
self.load_best_model_at_end = False
72+
self.save_strategy = IntervalStrategy.NO
73+
self.evaluation_strategy = IntervalStrategy.NO
74+
if not self.disable_tqdm:
75+
self.logging_steps = 1
76+
self.logging_strategy = IntervalStrategy.STEPS
77+
if self.max_steps > 0:
78+
self.num_train_epochs = 1
79+
80+
81+
@dataclass
82+
class KTOConfig:
83+
"""KTOConfig"""
84+
85+
beta: float = field(default=0.1, metadata={"help": "the beta parameter for KTO loss"})
86+
desirable_weight: float = field(default=1.0, metadata={"help": "desirable_weight"})
87+
undesirable_weight: float = field(default=1.0, metadata={"help": "undesirable_weight"})
88+
lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
89+
90+
91+
@dataclass
92+
class KTODataArgument:
93+
"""DataArgument"""
94+
95+
train_dataset_path: str = field(default="./data/train.jsonl", metadata={"help": "Path to the train dataset dir."})
96+
dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
97+
max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
98+
max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
99+
greedy_zero_padding: bool = field(
100+
default=False,
101+
metadata={"help": "Whether to use Greedy Zero Padding data stream."},
102+
)
103+
104+
105+
@dataclass
106+
class KTOModelArgument:
107+
"""ModelArgument"""
108+
109+
model_name_or_path: str = field(
110+
default=None, metadata={"help": "Pretrained model name or path to local directory."}
111+
)
112+
tokenizer_name_or_path: Optional[str] = field(
113+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
114+
)
115+
flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
116+
weight_quantize_algo: str = field(
117+
default=None,
118+
metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
119+
)
120+
fuse_attention_qkv: bool = field(
121+
default=None,
122+
metadata={"help": "whether to fuse attention qkv"},
123+
)
124+
fuse_attention_ffn: bool = field(
125+
default=None,
126+
metadata={"help": "whether to fuse first up and gate proj in mlp block"},
127+
)
128+
# LoRA
129+
lora_rank: int = field(default=8, metadata={"help": "Lora rank."})
130+
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
131+
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
132+
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
133+
lora_alpha: int = field(default=-1, metadata={"help": "lora_alpha"})
134+
rslora_plus: bool = field(default=False, metadata={"help": "Strengthen lora performance"})
135+
use_quick_lora: bool = field(default=True, metadata={"help": "quick lora"})
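The three hyperparameters in `KTOConfig` (`beta`, `desirable_weight`, `undesirable_weight`) correspond to the published KTO loss. A minimal scalar sketch of that formula — our own illustration for intuition, not the trainer's actual implementation:

```python
import math

def kto_loss(policy_logratio: float, ref_kl: float, chosen: bool,
             beta: float = 0.1, desirable_weight: float = 1.0,
             undesirable_weight: float = 1.0) -> float:
    """Per-example KTO loss: 1 minus the sigmoid of the beta-scaled margin
    between the policy/reference log-ratio and the KL reference point."""
    sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))
    if chosen:
        # desirable example: reward the log-ratio exceeding the reference point
        return desirable_weight * (1.0 - sigmoid(beta * (policy_logratio - ref_kl)))
    # undesirable example: reward the log-ratio falling below the reference point
    return undesirable_weight * (1.0 - sigmoid(beta * (ref_kl - policy_logratio)))

# a chosen answer the policy already prefers gets a smaller loss
print(kto_loss(5.0, 0.0, chosen=True) < kto_loss(-5.0, 0.0, chosen=True))  # → True
```

`desirable_weight` and `undesirable_weight` let you rebalance datasets where chosen and rejected examples are not equally frequent, which is exactly the pointwise setting the `sort` field encodes.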
