
Commit 7051550

add decode_token function (#2519)
1 parent 52504b7 commit 7051550

1 file changed: paddleformers/transformers/tokenizer_utils.py

Lines changed: 31 additions & 1 deletion
@@ -19,7 +19,7 @@
 import os
 import re
 from functools import wraps
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from transformers import BatchEncoding
 from transformers.tokenization_utils import (
@@ -447,6 +447,36 @@ def encode_chat_inputs(
         query = self._encode_chat_inputs_openai_format(conversations)
         return query
 
+    def decode_token(
+        self,
+        all_input_ids: List[int],
+        prefix_offset: int = 0,
+        read_offset: int = 0,
+        skip_special_tokens: bool = False,
+    ) -> Tuple[str, int, int]:
+        """Tokenizer decoding for the streaming generation use case. This method can be overridden for a tokenizer that doesn't follow this API."""
+        # The prefix text is necessary only to defeat cleanup algorithms in decode,
+        # which decide whether to add a space depending on the surrounding ids.
+        prefix_text = self.decode(
+            all_input_ids[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=False,
+        )
+        new_text = self.decode(
+            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False
+        )
+
+        if len(new_text) > len(prefix_text) and not prefix_text.endswith("�") and not new_text.endswith("�"):
+            # A utf-8 replacement char at the end means a potentially unfinished
+            # byte sequence from byte-fallback tokenization.
+            # If it appears in the middle, it is probably a real invalid id
+            # generated by the model.
+            prefix_index = new_text.index(prefix_text)
+            new_text = new_text[prefix_index + len(prefix_text) :]
+            return new_text, read_offset, len(all_input_ids)
+        else:
+            return "", prefix_offset, read_offset
+
 
 def warp_tokenizer(hf_tokenizer_class: PreTrainedTokenizer_tf):
     return type(hf_tokenizer_class.__name__, (PaddleTokenizerMixin, hf_tokenizer_class), {})
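For reference, the sketch below (not part of the commit) shows how a streaming caller might drive the new decode_token method: the two offsets returned by each call are fed back into the next call, and an empty string signals that the trailing ids still end in an unfinished byte sequence, so no text should be emitted yet. The stream_decode helper, the tokenizer argument, and the generated_ids input are illustrative assumptions; any tokenizer instance exposing the new method (presumably via the same mixin that defines encode_chat_inputs) would be used the same way.

from typing import Iterable, List

def stream_decode(tokenizer, generated_ids: Iterable[int]) -> str:
    # Illustrative helper, not from the commit: incrementally decode token ids
    # as they are produced, emitting only text that is safe to render.
    all_input_ids: List[int] = []
    prefix_offset, read_offset = 0, 0
    pieces: List[str] = []
    for token_id in generated_ids:
        all_input_ids.append(token_id)
        # decode_token returns (new_text, new_prefix_offset, new_read_offset);
        # new_text is "" while the tail is still an incomplete UTF-8 sequence.
        new_text, prefix_offset, read_offset = tokenizer.decode_token(
            all_input_ids,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=True,
        )
        if new_text:
            print(new_text, end="", flush=True)
            pieces.append(new_text)
    return "".join(pieces)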
