Commit b50db1c

[Bug Fixes] update chatglm1 tokenizer (PaddlePaddle#7870)
* update chatglm1 tokenizer
* update additional_special_token
* add is_training tag
* fix linting
1 parent: 2a8d138

File tree: 4 files changed, +14 -2 lines


llm/data.py

2 additions, 0 deletions

@@ -106,6 +106,8 @@ def tokenize_rounds_example(tokenizer, example, data_args):
 
     # 0. prepare data
     context_data = example.get("context", {})
+    context_data["is_training"] = True
+
     example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]]
     example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]]
 
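Why this change: tokenize_rounds_example builds training examples, so the commit pins is_training to True in context_data before the chat template is rendered. Since context_data keys become Jinja variables at render time (see render_query in tokenizer_utils.py below), a template can branch on the flag. A minimal sketch; the template string is made up for illustration and is not ChatGLM's real template:

# Sketch: context_data keys are exposed to the Jinja chat template.
# The template below is hypothetical, only to show an is_training branch.
from jinja2 import Template

tmpl = Template("Q: {{ query }}\nA:{% if not is_training %} <generate here>{% endif %}")

print(tmpl.render(query="What is 2 + 2?", is_training=True))   # training render
print(tmpl.render(query="What is 2 + 2?", is_training=False))  # inference render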

llm/predictor.py

3 additions, 2 deletions

@@ -49,6 +49,7 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     ChatGLMv2Tokenizer,
+    ChatGLMTokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
@@ -240,7 +241,7 @@ def _preprocess(self, source):
             padding=True,
             # when use chat_template, it should not add special tokens
             # chatglm2 prefix-tokens can not be tokenized into ids
-            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, ChatGLMv2Tokenizer),
+            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
         return tokenized_source
 
@@ -924,7 +925,7 @@ def _preprocess(self, source):
             max_length=self.config.src_length,
             # if use chat_template, it will not add special_tokens
             add_special_tokens=self.tokenizer.chat_template is None
-            or isinstance(self.tokenizer, ChatGLMv2Tokenizer),
+            or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
         input_ids = tokens["input_ids"][0]
         length = len(input_ids)
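The predictor-side effect: when a chat template is set, add_special_tokens is normally disabled, but ChatGLM v1 (like v2) relies on the tokenizer itself to append prefix tokens such as gMASK, which cannot be recovered from the rendered template text. A standalone sketch of the same rule, assuming any PaddleNLP tokenizer instance is passed in:

# Sketch of the add_special_tokens rule now used in llm/predictor.py.
from paddlenlp.transformers import ChatGLMTokenizer, ChatGLMv2Tokenizer

def should_add_special_tokens(tokenizer) -> bool:
    # No chat template: let the tokenizer add special tokens as usual.
    # ChatGLM v1/v2: always add them, since their prefix tokens cannot be
    # tokenized into ids from template text alone.
    return tokenizer.chat_template is None or isinstance(
        tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)
    )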

paddlenlp/transformers/chatglm/tokenizer.py

1 addition, 0 deletions

@@ -57,6 +57,7 @@ def __init__(
         num_image_tokens=20000,
         **kwargs
     ) -> None:
+        kwargs["additional_special_tokens"] = kwargs.pop("additional_special_tokens", []) + [gmask_token]
         super().__init__(
             pad_token=pad_token,
             unk_token=unk_token,
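Appending gmask_token to additional_special_tokens registers it with the base tokenizer, so it survives tokenization as a single unit; the kwargs.pop(...) pattern preserves any tokens the caller already passed. A quick check, assuming the default gmask_token value "[gMASK]" and a locally available ChatGLM checkpoint:

# Sketch: confirm gMASK is now a registered special token (the model name
# and the "[gMASK]" literal are assumptions based on ChatGLM defaults).
from paddlenlp.transformers import ChatGLMTokenizer

tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b")
print(tokenizer.additional_special_tokens)         # expected to include "[gMASK]"
print(tokenizer.convert_tokens_to_ids("[gMASK]"))  # one id, not subword pieces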

paddlenlp/transformers/tokenizer_utils.py

8 additions, 0 deletions

@@ -562,6 +562,11 @@ def render_query(self, query: str, index: int = 0, context_data: Dict[str, Union
         template = self._compile_jinja_template(self.query)
         return template.render(query=query, index=index, **context_data)
 
+    def _init_context_data(self, context_data: Dict[str, Union[int, str]] = {}) -> Dict[str, Union[int, str]]:
+        """init the context data for chat-template"""
+        context_data["is_training"] = context_data.get("is_training", False)
+        return context_data
+
     def render_system(self, context_data: Dict[str, Union[int, str]] = {}) -> str:
         if self.system is None:
             return ""
@@ -633,6 +638,8 @@ def apply_chat_template(
         Returns:
             str | dict[str, numpy.ndarray | paddle.Tensor]: return the result of applied data
         """
+        context_data = self.chat_template._init_context_data(context_data)
+
         if isinstance(conversation, str):
             conversation = [[conversation]]
         elif isinstance(conversation, list) and isinstance(conversation[0], str):
@@ -661,6 +668,7 @@ def encode_chat_inputs(self, conversations: List[List[str, str]], context_data:
         Returns:
             List[list[int], list[int]]: the pair of input_ids and target_ids
         """
+        context_data = self.chat_template._init_context_data(context_data)
         # encode system
         result = {}
         if self.chat_template.system:
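Both entry points now normalize context_data up front, so every chat template can rely on is_training being defined: True when llm/data.py injects it on the training path, False by default at inference. A minimal usage sketch; the model name is an assumption and the conversation string is arbitrary:

# Sketch: is_training defaults to False unless the caller sets it, as
# llm/data.py now does during training-data tokenization.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b")

inference_out = tokenizer.apply_chat_template("Hello!")  # is_training -> False
training_out = tokenizer.apply_chat_template("Hello!", context_data={"is_training": True})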
