Commit de621f9

fix chat_template bug. (#2552)

1 parent d6501d3

8 files changed: +80 −23

examples/config/qwen/lora_argument_qwen2_0p5b.json

Lines changed: 11 additions & 2 deletions
@@ -1,7 +1,15 @@
 {
     "model_name_or_path": "PaddleNLP/Qwen2-0.5B-Instruct",
-    "dataset_name_or_path": "./data/sft",
+    "train_dataset_path": "./data/sft/train.json",
+    "train_dataset_prob": "1.0",
+    "train_dataset_type": "erniekit",
+    "eval_dataset_path": "./data/sft/dev.json",
+    "eval_dataset_prob": "1.0",
+    "eval_dataset_type": "erniekit",
+    "packing": true,
+    "mix_strategy": "random",
     "output_dir": "./checkpoints/qwen2_paddle_lora_ckpts",
+    "max_seq_len": 8192,
     "per_device_train_batch_size": 1,
     "gradient_accumulation_steps": 4,
     "per_device_eval_batch_size": 8,
@@ -32,5 +40,6 @@
     "unified_checkpoint": true,
     "use_flash_attention": false,
     "pissa": false,
-    "use_mora": false
+    "use_mora": false,
+    "encode_one_turn": true
 }

examples/config/qwen/sft_argument_qwen2_0p5b.json

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 {
-    "model_name_or_path": "/root/.cache/aistudio/hub/models/PaddleNLP/Qwen2-0.5B-Instruct",
+    "model_name_or_path": "PaddleNLP/Qwen2-0.5B-Instruct",
     "train_dataset_path": "./data/sft/train.json",
     "train_dataset_prob": "1.0",
     "train_dataset_type": "erniekit",
@@ -39,5 +39,6 @@
     "zero_padding": true,
     "flash_mask": true,
     "unified_checkpoint": true,
-    "use_flash_attention": true
+    "use_flash_attention": true,
+    "encode_one_turn": true
 }

examples/run_finetune.py

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ def neft_post_hook(module, input, output):
         "greedy_intokens": data_args.greedy_intokens,
         "packing": data_args.packing,
         "mix_strategy": data_args.mix_strategy,
+        "encode_one_turn": data_args.encode_one_turn,
     }

     train_dataset = create_dataset_sft(
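
Note: the new key simply rides the existing dataset_config dict from the parsed data arguments into the dataset factory. A minimal sketch of the call site (only the keys relevant to this commit are shown; create_dataset_sft takes more kwargs than this, and `data_args` is the parsed DataConfig from paddleformers/trl/sftdata_config.py below):

    # Hedged sketch, not the full script:
    dataset_config = {
        "packing": data_args.packing,
        "mix_strategy": data_args.mix_strategy,
        "encode_one_turn": data_args.encode_one_turn,  # new flag, defaults to True
    }
    train_dataset = create_dataset_sft(**dataset_config)  # plus the remaining kwargs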

paddleformers/datasets/dpo.py

Lines changed: 9 additions & 2 deletions
@@ -116,6 +116,7 @@ def create_dataset(**dataset_config):
         mask_out_eos_token=dataset_config["mask_out_eos_token"],
         packing=dataset_config["packing"],
         mix_strategy=dataset_config["mix_strategy"],
+        encode_one_turn=dataset_config["encode_one_turn"],
     )
     return sequence_dataset

@@ -389,6 +390,7 @@ def __init__(
         mask_out_eos_token: bool = True,
         packing: bool = False,
         mix_strategy: str = "random",
+        encode_one_turn: bool = True,
     ):
         self.example_dataset = dataset
         self.tokenizer = tokenizer
@@ -415,6 +417,7 @@ def __init__(
         self.mask_out_eos_token = mask_out_eos_token
         self.packing = packing
         self.mix_strategy = mix_strategy
+        self.encode_one_turn = encode_one_turn
         self.num_samples_each_epoch = num_samples_each_epoch

         # For new data concatenation mode
@@ -594,8 +597,12 @@ def __postprocess_before_concat(self, example):
         # encoded_messages: List[List[int]]
         if not self.tokenizer.chat_template:
             self.tokenizer.init_chat_template(NONE_CHAT_TEMPLATE)
-        chosen_encoded_messages = self.tokenizer.encode_chat_inputs(example.chosen)
-        rejected_encoded_messages = self.tokenizer.encode_chat_inputs(example.rejected)
+        chosen_encoded_messages = self.tokenizer.encode_chat_inputs(
+            example.chosen, encode_one_turn=self.encode_one_turn
+        )
+        rejected_encoded_messages = self.tokenizer.encode_chat_inputs(
+            example.rejected, encode_one_turn=self.encode_one_turn
+        )

         # chosen/rejected response
         response_token_ids_list = []
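
The point of threading one flag through both calls is symmetry: chosen and rejected conversations are segmented under the same policy, so their shared prompt prefixes tokenize identically on both sides. A hedged illustration (the record layout below is assumed for illustration; only the .chosen/.rejected attributes follow the diff):

    # Hypothetical preference pair; real dataset examples expose .chosen/.rejected.
    chosen = {
        "messages": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "7"},
        ]
    }
    rejected = {
        "messages": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "8"},
        ]
    }
    # Encoding both with the same encode_one_turn keeps the prompt token ids aligned:
    #   chosen_ids   = tokenizer.encode_chat_inputs(chosen, encode_one_turn=True)
    #   rejected_ids = tokenizer.encode_chat_inputs(rejected, encode_one_turn=True)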

paddleformers/datasets/finetuning.py

Lines changed: 4 additions & 1 deletion
@@ -82,6 +82,7 @@ def create_dataset(**dataset_config):
         greedy_intokens=dataset_config["greedy_intokens"],
         packing=dataset_config["packing"],
         mix_strategy=dataset_config["mix_strategy"],
+        encode_one_turn=dataset_config["encode_one_turn"],
     )
     return sequence_dataset

@@ -285,6 +286,7 @@ def __init__(
         greedy_intokens: bool = False,
         packing: bool = False,
         mix_strategy: str = "random",
+        encode_one_turn: bool = True,
     ):
         """Initialize SequenceDataset.

@@ -314,6 +316,7 @@ def __init__(
         self.greedy_intokens = greedy_intokens
         self.packing = packing
         self.mix_strategy = mix_strategy
+        self.encode_one_turn = encode_one_turn
         self.num_samples_each_epoch = num_samples_each_epoch
         self.reverse = True

@@ -536,7 +539,7 @@ def _postprocess_sequence(self, example, actual_example_num):
         if example.is_function_call:
             encoded_messages = self._postprocess_fc_sequence(example)
         else:
-            encoded_messages = self.tokenizer.encode_chat_inputs(example.request)
+            encoded_messages = self.tokenizer.encode_chat_inputs(example.request, encode_one_turn=self.encode_one_turn)

         num_reserved_tokens_for_each_dialog = 1  # only break_turn_token or end_token
         num_reserved_tokens_for_each_turn = 8
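
For orientation, encode_chat_inputs hands back one entry per dialogue round, each a [prompt_ids, response_ids] pair, which _postprocess_sequence then fits into the sequence-length budget using the reserved-token counts above. A sketch of the shape (token ids fabricated for illustration):

    # Illustrative shape only, not real token ids:
    encoded_messages = [
        [[101, 5, 9], [42, 7]],    # round 1: [prompt_ids, response_ids]
        [[88, 3], [12, 6, 2]],     # round 2: encoded independently of round 1
    ]                              # when encode_one_turn=True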

paddleformers/transformers/tokenizer_utils.py

Lines changed: 48 additions & 1 deletion
@@ -377,6 +377,49 @@ def _encode_chat_inputs_openai_format(

         return conversation_ids

+    def _encode_chat_inputs_oneturn(
+        self,
+        conversations: Dict[str, Any],
+        add_generation_prompt=True,
+    ):
+        conversation_dict = {} if "tools" not in conversations else {"tools": conversations["tools"]}
+        conversation_dict["messages"] = (
+            [conversations["messages"][0]] if conversations["messages"][0]["role"] == "system" else []
+        )
+
+        if conversations["messages"][0]["role"] == "system":
+            conversations["messages"] = conversations["messages"][1:]
+
+        cur_str = ""
+        conversation_ids = []
+        for idx in range(0, len(conversations["messages"]), 2):
+            conversation_id = []
+            conversation_dict["messages"].append(conversations["messages"][idx])
+            round_str = self.apply_chat_template(
+                conversation_dict["messages"], add_generation_prompt=True, tokenize=False
+            )
+            # query: user prefix + user content + assist prefix
+            query = round_str[len(cur_str) :]
+            input_ids = self.convert_tokens_to_ids(self.tokenize(query))
+            conversation_id.append(input_ids)
+            cur_str = round_str
+
+            if idx + 1 < len(conversations["messages"]):
+                conversation_dict["messages"].append(conversations["messages"][idx + 1])
+                round_str = self.apply_chat_template(
+                    conversation_dict["messages"], add_generation_prompt=False, tokenize=False
+                )
+                # answer: assistant content
+                answer = round_str[len(cur_str) :]
+                output_ids = self.convert_tokens_to_ids(self.tokenize(answer))
+                conversation_id.append(output_ids)
+
+            conversation_ids.append(conversation_id)
+            conversation_dict["messages"] = []
+            cur_str = ""
+
+        return conversation_ids
+
     def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
         """Split the entire chat by specified words. Extract the non-learnable parts."""
         # TODO:We will upgrade this feature later
@@ -458,14 +501,18 @@ def encode_chat_inputs(
         if not self.chat_template:
             raise ValueError("chat_template is not set, please set chat_template first.")
         else:
+            encode_one_turn = kwargs.pop("encode_one_turn", True)
             add_generation_prompt = kwargs.pop("add_generation_prompt", True)
             if not isinstance(conversations, dict):
                 query = self._encode_chat_inputs(
                     conversations, context_data, add_generation_prompt=add_generation_prompt
                 )
             else:
                 conversations.update(add_generation_prompt=add_generation_prompt)
-                query = self._encode_chat_inputs_openai_format(conversations)
+                if encode_one_turn:
+                    query = self._encode_chat_inputs_oneturn(conversations)
+                else:
+                    query = self._encode_chat_inputs_openai_format(conversations)
         return query

     def decode_token(
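
Taken together: with encode_one_turn=True (the new default), each (user, assistant) round is rendered through the chat template and tokenized on its own, with only the system message prepended to the first round; encode_one_turn=False keeps the previous behavior, where each round carries the full running context. A hedged usage sketch (assumes the tokenizer import path used in this repo's tests and a locally available model):

    from paddleformers.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("PaddleNLP/Qwen2-0.5B-Instruct")
    conversations = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "Count to three"},
            {"role": "assistant", "content": "1, 2, 3"},
        ]
    }

    # Fresh shallow copies, since encode_chat_inputs mutates its input dict.
    # Each call returns a list of rounds, each a [prompt_ids, response_ids] pair.
    per_turn = tokenizer.encode_chat_inputs(dict(conversations), encode_one_turn=True)
    full_ctx = tokenizer.encode_chat_inputs(dict(conversations), encode_one_turn=False)

    # Under per-turn encoding, round 2's prompt should omit round 1's text.
    print(len(per_turn[1][0]), len(full_ctx[1][0]))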

paddleformers/trl/sftdata_config.py

Lines changed: 4 additions & 0 deletions
@@ -57,6 +57,10 @@ class DataConfig:
             "help": "Strategy to use in dataset mixing (random/concat/interleave) (undersampling/oversampling)."
         },
     )
+    encode_one_turn: bool = field(
+        default=True,
+        metadata={"help": "Whether encode each round independently in a multi-round dialogue."},
+    )
     packing: bool = field(
         default=True,
         metadata={"help": "Enable sequences packing in training."},

tests/transformers/test_hf_tokenizer.py

Lines changed: 0 additions & 15 deletions
@@ -128,21 +128,6 @@ def test_dict_apply_chat_template(self):


 class TestPaddleTokenizerMethod(unittest.TestCase):
-    def test_encode_chat_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("PaddleNLP/Qwen2.5-7B", download_hub="aistudio")
-        query = [["你好", "您好,我是个人人工智能助手"], ["今天吃啥", "你可以选择不同的菜系"]]
-        encode_text = tokenizer.encode_chat_inputs(query)
-        dict_query = {
-            "messages": [
-                {"role": "user", "content": "你好"},
-                {"role": "assistant", "content": "您好,我是个人人工智能助手"},
-                {"role": "user", "content": "今天吃啥"},
-                {"role": "assistant", "content": "你可以选择不同的菜系"},
-            ]
-        }
-        encode_dict_text = tokenizer.encode_chat_inputs(dict_query)
-        self.assertListEqual(encode_text["conversations"], encode_dict_text)
-
     def test_tokenizer_decode_token(self) -> None:
         tokenizer = AutoTokenizer.from_pretrained("PaddleNLP/Qwen2.5-7B", download_hub="aistudio")
         test_cases = ["1. 百度 2. 腾讯", "hello world! I like eating banana", "🤓😖", "🤓😖testtest"]
