|
19 | 19 | import os
|
20 | 20 | import re
|
21 | 21 | from functools import wraps
|
22 |
| -from typing import Any, Dict, List, Optional, Union |
| 22 | +from typing import Any, Dict, List, Optional, Tuple, Union |
23 | 23 |
|
24 | 24 | from transformers import BatchEncoding
|
25 | 25 | from transformers.tokenization_utils import (
|
@@ -447,6 +447,36 @@ def encode_chat_inputs(
|
447 | 447 | query = self._encode_chat_inputs_openai_format(conversations)
|
448 | 448 | return query
|
449 | 449 |
|
def decode_token(
    self,
    all_input_ids: List[int],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
) -> Tuple[str, int, int]:
    """Incrementally decode newly generated tokens for streaming generation.

    This method can be overridden for tokenizers that don't follow this API.

    Args:
        all_input_ids: Every token id produced so far.
        prefix_offset: Start of the window that is decoded on each call.
        read_offset: End of the part of the window whose text was already
            emitted; ids from here on are the new ones.
        skip_special_tokens: Forwarded unchanged to ``self.decode``.

    Returns:
        A ``(text, prefix_offset, read_offset)`` tuple. When new text can be
        emitted, ``text`` is the newly decoded suffix and the offsets advance
        to ``(read_offset, len(all_input_ids))``. Otherwise ``text`` is ``""``
        and the incoming offsets are returned unchanged so the caller can
        retry once more ids are available.
    """
    # U+FFFD REPLACEMENT CHARACTER: what decode emits for an invalid or
    # incomplete byte sequence.
    replacement_char = "\ufffd"

    # The prefix text is necessary only to defeat cleanup algorithms in the
    # decode which decide to add a space or not depending on the surrounding
    # ids.
    prefix_text = self.decode(
        all_input_ids[prefix_offset:read_offset],
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=False,
    )
    new_text = self.decode(
        all_input_ids[prefix_offset:],
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=False,
    )

    # A replacement character at the end means a potentially unfinished byte
    # sequence from byte-fallback tokenization, so wait for more ids. If it
    # sits in the middle it is probably a real invalid id generated by the
    # model and is emitted as-is.
    if (
        len(new_text) <= len(prefix_text)
        or prefix_text.endswith(replacement_char)
        or new_text.endswith(replacement_char)
    ):
        return "", prefix_offset, read_offset

    # Locate the already-emitted prefix inside the full decode. ``str.find``
    # is used instead of ``str.index`` so that a decoder whose output for the
    # longer sequence no longer contains the prefix verbatim does not abort
    # streaming with ValueError; in that case fall back to skipping
    # ``len(prefix_text)`` characters from the start.
    prefix_index = new_text.find(prefix_text)
    if prefix_index < 0:
        prefix_index = 0
    emitted_text = new_text[prefix_index + len(prefix_text) :]
    return emitted_text, read_offset, len(all_input_ids)
| 479 | + |
450 | 480 |
|
def warp_tokenizer(hf_tokenizer_class: PreTrainedTokenizer_tf):
    """Build a new class that combines ``PaddleTokenizerMixin`` with the given tokenizer class.

    The generated class keeps the name of ``hf_tokenizer_class`` and lists
    ``PaddleTokenizerMixin`` first among its bases, so the mixin's attributes
    take precedence in the method resolution order.
    """
    wrapped_name = hf_tokenizer_class.__name__
    base_classes = (PaddleTokenizerMixin, hf_tokenizer_class)
    return type(wrapped_name, base_classes, {})
|
|
0 commit comments