Skip to content

Commit 1982091

Browse files
authored
[Improvement] support system prompt in training dataset (#7667)
* update system prompt * support system prompt * add context system prompt docs * update chat-template readme
1 parent 6ddb4b1 commit 1982091

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

llm/data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):
105105
"""
106106

107107
# 0. prepare data
108+
context_data = example.get("context", {})
108109
example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]]
109110
example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]]
110111

@@ -113,7 +114,9 @@ def tokenize_rounds_example(tokenizer, example, data_args):
113114
conversations = [[src, tgt] for src, tgt in zip(example["src"], example["tgt"])]
114115

115116
# 1. only tokenize input_ids
116-
conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(conversations)
117+
conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
118+
conversations, context_data=context_data
119+
)
117120
system_ids = conversation_result.pop("system", []) or []
118121

119122
# 2. truncate conversations based on conversation unit

llm/docs/chat_template.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,29 @@ python finetune_generation.py ... --chat_template ./qwen_14b_chat_template.json
5757
1. `chat_template` 参数和 `model_name_or_path` 参数一致时,将默认使用模型自带的 `chat_template.json` 文件。
5858
1. `chat_template` 参数为文件路径时,将使用该文件中的 `chat_template` 配置。
5959
1. `chat_template` 参数为空时,不使用 `chat_template` 配置进行训练。
60+
61+
#### 如何自定义system prompt
62+
63+
如果想要在训练或者推理的过程中动态调整 system prompt,需要进行以下调整:
64+
65+
1. 需要保证 `chat_template.json` 文件中的 system 配置包含 jinja2 中的变量占位符(比如:`<|im_start|>user\n{{user}}<|im_end|>` 中的 `{{user}}` 就是一个变量占位符),同时尽量让其保留默认值,比如上述配置可调整成:
66+
67+
> 需要开发者手动修改 `chat_template.json` 来实现动态调整 system prompt。
68+
69+
```diff
70+
{
71+
- "system": "You are a helpful assistant.",
72+
+ "system": "{{system | 'You are a helpful assistant.'}}",
73+
"conversation": ["\n<|im_start|>user\n{{user}}<|im_end|>\n<|im_start|>assistant\n", "{{bot}}<|im_end|>"],
74+
"query": "\n<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n",
75+
}
76+
```
77+
78+
2. 训练文本数据中需要配置 `context` 字段,通过它将 `system` 字段传递进去,示例数据为:
79+
80+
```json
81+
{"src": ["user-1", "user-2", ..., "user-n"], "tgt": ["bot-1", "bot-2", ..., "bot-n"], "context": {"system": "你是一个擅长做任务的人工智能助手"}}
82+
...
83+
```
84+
85+
在渲染 chat_template 的时候,将以上数据中的 `context` 作为 jinja2 的上下文数据,这样就可以为训练数据集中的每条训练数据定制 system prompt。

paddlenlp/transformers/chatglm_v2/tokenizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import os
17-
from typing import Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Optional, Union
1818

1919
import numpy as np
2020
from sentencepiece import SentencePieceProcessor
@@ -280,9 +280,9 @@ def _pad(
280280

281281
return encoded_inputs
282282

283-
def encode_chat_inputs(self, conversations: List[List[str, str]]):
283+
def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
284284
# encode system
285-
result = super().encode_chat_inputs(conversations)
285+
result = super().encode_chat_inputs(conversations, context_data=context_data)
286286
if "system" in result:
287287
result["system"] = self.get_prefix_tokens() + result["system"]
288288
else:

paddlenlp/transformers/tokenizer_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,26 +647,30 @@ def apply_chat_template(
647647
tokenizer_kwargs["add_special_tokens"] = False
648648
return self(query, **tokenizer_kwargs)
649649

650-
def encode_chat_inputs(self, conversations: List[List[str, str]]):
650+
def encode_chat_inputs(self, conversations: List[List[str, str]], context_data: Dict[str, Any] = {}):
651651
"""Encodes conversation to pairs of token ids.
652652
Turn 0: bos + system + sep + user bot + eos
653653
Turn t: sep + bot + query bot + eos
654654
655655
Args:
656656
conversation (List[List[str, str]]): the conversation of data
657+
context_data (Dict[str, Any]): the context data of conversation
657658
658659
Returns:
659660
List[list[int], list[int]]: the pair of input_ids and target_ids
660661
"""
661662
# encode system
662663
result = {}
663664
if self.chat_template.system:
664-
result["system"] = self.encode(self.chat_template.system, add_special_tokens=False)["input_ids"]
665+
system = self.chat_template.render_system(context_data)
666+
result["system"] = self.encode(system, add_special_tokens=False)["input_ids"]
665667

666668
# encode conversation
667669
conversation_ids = []
668670
for index, conversation in enumerate(conversations):
669-
user_input, bot_output = self.chat_template.render_conversation(conversation, index=index)
671+
user_input, bot_output = self.chat_template.render_conversation(
672+
conversation, index=index, context_data=context_data
673+
)
670674
user_ids = self.encode(user_input, add_special_tokens=False)["input_ids"]
671675
bot_ids = self.encode(bot_output, add_special_tokens=False)["input_ids"]
672676
conversation_ids.append([user_ids, bot_ids])

0 commit comments

Comments
 (0)