
Commit 32e4a08

add no template (#2676)
1 parent 1934086 commit 32e4a08

File tree

examples/run_finetune.py
paddleformers/datasets/finetuning.py
paddleformers/transformers/tokenizer_utils.py
paddleformers/trl/sftdata_config.py
tests/dataset/test_ernie_datasets.py

5 files changed: +101 additions, -24 deletions

examples/run_finetune.py

Lines changed: 1 addition & 0 deletions
@@ -204,6 +204,7 @@ def neft_post_hook(module, input, output):
         "packing": data_args.packing,
         "mix_strategy": data_args.mix_strategy,
         "encode_one_turn": data_args.encode_one_turn,
+        "use_template": data_args.use_template,
     }
 
     train_dataset = create_dataset_sft(
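For reference, the dataset factory now expects this extra key. Below is a minimal sketch of the dict that ends up feeding create_dataset_sft; every value except the new key is a placeholder rather than the example script's real default, and the exact call signature is not shown in this diff:

dataset_config = {
    "packing": False,            # placeholder value
    "mix_strategy": "random",    # placeholder value
    "encode_one_turn": True,     # placeholder value
    "use_template": True,        # new key from this commit; set False to skip chat-template rendering
}
train_dataset = create_dataset_sft(**dataset_config)  # keyword expansion assumed, mirroring create_dataset(**dataset_config)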

paddleformers/datasets/finetuning.py

Lines changed: 44 additions & 24 deletions
@@ -83,6 +83,7 @@ def create_dataset(**dataset_config):
         packing=dataset_config["packing"],
         mix_strategy=dataset_config["mix_strategy"],
         encode_one_turn=dataset_config["encode_one_turn"],
+        use_template=dataset_config["use_template"],
     )
     return sequence_dataset
 
@@ -289,6 +290,7 @@ def __init__(
         packing: bool = False,
         mix_strategy: str = "random",
         encode_one_turn: bool = True,
+        use_template: bool = True,
     ):
         """Initialize SequenceDataset.
 
@@ -319,6 +321,7 @@ def __init__(
         self.packing = packing
         self.mix_strategy = mix_strategy
         self.encode_one_turn = encode_one_turn
+        self.use_template = use_template
         self.num_samples_each_epoch = num_samples_each_epoch
         self.reverse = True
 
@@ -536,12 +539,19 @@ def _postprocess_sequence(self, example, actual_example_num):
         Returns:
             Sequence: Processed sequence or None if invalid.
         """
-        if not self.tokenizer.chat_template:
-            self.tokenizer.chat_template = NONE_CHAT_TEMPLATE
-        if example.is_function_call:
-            encoded_messages = self._postprocess_fc_sequence(example)
+        if self.use_template:
+            if not self.tokenizer.chat_template:
+                self.tokenizer.chat_template = NONE_CHAT_TEMPLATE
+            if example.is_function_call:
+                encoded_messages = self._postprocess_fc_sequence(example)
+            else:
+                encoded_messages = self.tokenizer.encode_chat_inputs(
+                    example.request, encode_one_turn=self.encode_one_turn
+                )
         else:
-            encoded_messages = self.tokenizer.encode_chat_inputs(example.request, encode_one_turn=self.encode_one_turn)
+            encoded_messages = self.tokenizer.encode_chat_inputs_with_no_template(
+                example.request, encode_one_turn=self.encode_one_turn
+            )
 
         num_reserved_tokens_for_each_dialog = 1  # only break_turn_token or end_token
         num_reserved_tokens_for_each_turn = 8
@@ -585,26 +595,36 @@ def _postprocess_sequence(self, example, actual_example_num):
 
             return None
 
-        if self.begin_token_id is not None and self.end_of_response_id is not None:
-            # Maybe left truncated, so need to add begin_token
-            if tokens[0] != self.begin_token_id:
-                tokens = [self.begin_token_id] + tokens
-                loss_mask = [0] + loss_mask
-
-            if len(tokens) > self.max_seq_len:
-                raise RuntimeError(f"token_ids is too long: {len(tokens)}")
-
-            # Add EOS token at the end
-            del tokens[-1]
-            del loss_mask[-1]
-            labels = tokens[1:] + [self.tokenizer.eos_token_id]
-
-            # end_of_response is a special token that indicates the end of the turn.
-            # end_token is a special token that indicates the end of the answer.
-            labels = [label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels]
+        if self.use_template:
+            if self.begin_token_id is not None and self.end_of_response_id is not None:
+                # Maybe left truncated, so need to add begin_token
+                if tokens[0] != self.begin_token_id:
+                    tokens = [self.begin_token_id] + tokens
+                    loss_mask = [0] + loss_mask
+
+                if len(tokens) > self.max_seq_len:
+                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")
+
+                # Add EOS token at the end
+                del tokens[-1]
+                del loss_mask[-1]
+                labels = tokens[1:] + [self.tokenizer.eos_token_id]
+
+                # end_of_response is a special token that indicates the end of the turn.
+                # end_token is a special token that indicates the end of the answer.
+                labels = [
+                    label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels
+                ]
+            else:
+                tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
+                labels = tokens[1:] + [-100]
+                if len(tokens) > self.max_seq_len:
+                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")
         else:
-            tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
-            labels = tokens[1:] + [-100]
+            oral_tokens = tokens
+            tokens = oral_tokens[:-1]
+            labels = oral_tokens[1:]
+            loss_mask = loss_mask[1:]
         if len(tokens) > self.max_seq_len:
             raise RuntimeError(f"token_ids is too long: {len(tokens)}")
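In the new no-template branch the training pair is a plain one-token shift of the raw encoded sequence, with the loss mask realigned the same way. A toy walk-through of those four lines (the token ids below are invented, not real vocabulary entries):

oral_tokens = [11, 22, 33, 44]   # full encoded sequence
tokens = oral_tokens[:-1]        # [11, 22, 33] -> model inputs
labels = oral_tokens[1:]         # [22, 33, 44] -> next-token targets
loss_mask = [0, 0, 1, 1][1:]     # [0, 1, 1]    -> realigned with the shifted labels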

paddleformers/transformers/tokenizer_utils.py

Lines changed: 48 additions & 0 deletions
@@ -515,6 +515,54 @@ def encode_chat_inputs(
             query = self._encode_chat_inputs_openai_format(conversations)
         return query
 
+    def encode_chat_inputs_with_no_template(
+        self, conversations: List[List[str, str]] | Dict[str, Any], context_data: Dict[str, Any] = {}, **kwargs
+    ):
+        """
+        Args:
+            conversation (List[List[str, str]]): the conversation of data
+            context_data (Dict[str, Any]): the context data of conversation
+
+        Returns:
+            List[list[int], list[int]]: the pair of input_ids and target_ids
+        """
+        assert isinstance(conversations, dict)
+
+        conversation_dict = {} if "tools" not in conversations else {"tools": conversations["tools"]}
+        conversation_dict["messages"] = (
+            [conversations["messages"][0]] if conversations["messages"][0]["role"] == "system" else []
+        )
+
+        if conversations["messages"][0]["role"] == "system":
+            conversations["messages"] = conversations["messages"][1:]
+
+        cur_str = ""
+        conversation_ids = []
+        for idx in range(0, len(conversations["messages"]), 2):
+            conversation_id = []
+            conversation_dict["messages"].append(conversations["messages"][idx])
+            round_str = conversation_dict["messages"]
+            # fake template
+            tokenize_input = "".join(item["content"] for item in round_str)
+            tokenize_input = tokenize_input[len(cur_str) :]
+            input_ids = self.convert_tokens_to_ids(self.tokenize(tokenize_input))
+            conversation_id.append(input_ids)
+            cur_str = tokenize_input
+
+            if idx + 1 < len(conversations["messages"]):
+                conversation_dict["messages"].append(conversations["messages"][idx + 1])
+                round_str = conversation_dict["messages"]
+                # fake template
+                tokenize_input = "".join(item["content"] for item in round_str)
+                tokenize_input = tokenize_input[len(cur_str) :]
+                output_ids = self.convert_tokens_to_ids(self.tokenize(tokenize_input))
+                conversation_id.append(output_ids)
+
+            conversation_ids.append(conversation_id)
+            conversation_dict["messages"] = []
+            cur_str = ""
+        return conversation_ids
+
     def decode_token(
         self,
         all_input_ids: List[int],
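A hypothetical call of the new method; the tokenizer instance and the message contents below are made up, and only the method name and the OpenAI-style dict shape come from this diff. Each user/assistant round yields one [input_ids, output_ids] pair tokenized from the raw concatenated content with no chat template applied, and a leading system message is folded into the first round's input:

conversations = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant. "},
        {"role": "user", "content": "Hello. "},
        {"role": "assistant", "content": "Hi, how can I help? "},
    ]
}
# `tokenizer` stands for any tokenizer that carries this mixin (assumption).
conversation_ids = tokenizer.encode_chat_inputs_with_no_template(conversations)
# -> [[ids("You are a helpful assistant. Hello. "), ids("Hi, how can I help? ")]]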

paddleformers/trl/sftdata_config.py

Lines changed: 4 additions & 0 deletions
@@ -57,6 +57,10 @@ class DataConfig:
             "help": "Strategy to use in dataset mixing (random/concat/interleave) (undersampling/oversampling)."
         },
     )
+    use_template: bool = field(
+        default=True,
+        metadata={"help": "Whether to use template in data processing."},
+    )
     encode_one_turn: bool = field(
         default=True,
         metadata={"help": "Whether encode each round independently in a multi-round dialogue."},
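A minimal sketch of overriding the new field programmatically; it assumes the remaining DataConfig fields all carry defaults (not shown in this diff), and the import path simply mirrors the file path above:

from paddleformers.trl.sftdata_config import DataConfig  # module path taken from the diff

data_args = DataConfig(use_template=False)  # keep the default True to preserve current behavior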

tests/dataset/test_ernie_datasets.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@ def test_random_dataset_len(self):
             "packing": False,
             "mix_strategy": "random",
             "encode_one_turn": True,
+            "use_template": True,
         }
 
         train_dataset = create_dataset_sft(
@@ -71,6 +72,7 @@ def test_concat_dataset_len(self):
             "packing": False,
             "mix_strategy": "concat",
             "encode_one_turn": True,
+            "use_template": True,
         }
 
         train_dataset = create_dataset_sft(
@@ -100,6 +102,7 @@ def test_interleave_under_dataset_len(self):
             "packing": False,
             "mix_strategy": "interleave_under",
             "encode_one_turn": True,
+            "use_template": True,
         }
 
         train_dataset = create_dataset_sft(
@@ -129,6 +132,7 @@ def test_interleave_over_dataset_len(self):
             "packing": False,
             "mix_strategy": "interleave_over",
             "encode_one_turn": True,
+            "use_template": True,
         }
 
         train_dataset = create_dataset_sft(
