
Commit 524f2a2

[BufFix] fix decode_token (#2553)
1 parent 4384816 commit 524f2a2

2 files changed: +17 −7 lines

paddleformers/transformers/legacy/tokenizer_utils_base.py

Lines changed: 8 additions & 3 deletions
@@ -3496,10 +3496,15 @@ def decode_token(
             # from byte fallback tokenization.
             # If it's in the middle, it's probably a real invalid id generated
             # by the model
-            prefix_index = new_text.index(prefix_text)
-            new_text = new_text[prefix_index + len(prefix_text) :]
-            return new_text, read_offset, len(all_input_ids)
+            if new_text.startswith(prefix_text):
+                prefix_index = new_text.index(prefix_text)
+                new_text = new_text[prefix_index + len(prefix_text) :]
+                return new_text, read_offset, len(all_input_ids)
+            else:
+                return "", prefix_offset, len(all_input_ids)
         else:
+            if len(all_input_ids[prefix_offset:]) > 3:
+                return new_text, len(all_input_ids), len(all_input_ids)
             return "", prefix_offset, read_offset

     def batch_decode(
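Why the new guard matters: the old code called new_text.index(prefix_text) unconditionally, which raises ValueError when the re-decoded text no longer contains the previously emitted prefix (and can strip from the wrong position if the prefix only appears later in the string). Below is a minimal, self-contained sketch of the guarded incremental-decode pattern; ToyTokenizer and decode_token_sketch are hypothetical stand-ins for illustration, not the PaddleFormers API.

# Minimal sketch of the guarded prefix stripping (assumed behavior, not the
# PaddleFormers implementation); all names here are illustrative.

class ToyTokenizer:
    """Toy stand-in: each id maps to a fixed piece of text."""

    def __init__(self, vocab):
        self.vocab = vocab

    def decode(self, ids):
        # Unknown ids decode to the Unicode replacement character.
        return "".join(self.vocab.get(i, "\ufffd") for i in ids)


def decode_token_sketch(tokenizer, all_input_ids, prefix_offset, read_offset):
    """Return (delta_text, new_prefix_offset, new_read_offset) for incremental decoding."""
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])

    if len(new_text) > len(prefix_text) and "\ufffd" not in prefix_text and "\ufffd" not in new_text:
        if new_text.startswith(prefix_text):
            # Normal case: emit only the text added since the last call.
            return new_text[len(prefix_text):], read_offset, len(all_input_ids)
        # Guarded case added by this commit: the re-decoded text no longer
        # starts with the old prefix, so emit nothing instead of crashing.
        return "", prefix_offset, len(all_input_ids)
    # Incomplete byte sequence: keep buffering.
    return "", prefix_offset, read_offset


tok = ToyTokenizer({1: "He", 2: "llo", 3: " world"})
print(decode_token_sketch(tok, [1, 2, 3], prefix_offset=0, read_offset=2))
# -> (' world', 2, 3)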

paddleformers/transformers/tokenizer_utils.py

Lines changed: 9 additions & 4 deletions
@@ -487,15 +487,20 @@ def decode_token(
             all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False
         )

-        if len(new_text) > len(prefix_text) and not prefix_text.endswith("�") and not new_text.endswith("�"):
+        if len(new_text) > len(prefix_text) and "�" not in prefix_text and "�" not in new_text:
             # utf-8 char at the end means it's a potential unfinished byte sequence
             # from byte fallback tokenization.
             # If it's in the middle, it's probably a real invalid id generated
             # by the model
-            prefix_index = new_text.index(prefix_text)
-            new_text = new_text[prefix_index + len(prefix_text) :]
-            return new_text, read_offset, len(all_input_ids)
+            if new_text.startswith(prefix_text):
+                prefix_index = new_text.index(prefix_text)
+                new_text = new_text[prefix_index + len(prefix_text) :]
+                return new_text, read_offset, len(all_input_ids)
+            else:
+                return "", prefix_offset, len(all_input_ids)
         else:
+            if len(all_input_ids[prefix_offset:]) > 3:
+                return new_text, len(all_input_ids), len(all_input_ids)
             return "", prefix_offset, read_offset
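The second change in both files is the new fallback in the else branch: instead of returning an empty string indefinitely while the pending ids keep decoding to replacement characters, the code now flushes the partial text once more than 3 ids are pending. A rough sketch of that decision follows, assuming all_input_ids[prefix_offset:] is the still-undecoded window; flush_or_wait is a hypothetical helper, not part of the library.

# Sketch of the flush-after-3 fallback introduced in this commit.

def flush_or_wait(new_text, all_input_ids, prefix_offset, read_offset):
    """Decide what to emit when the decoded text is still incomplete."""
    pending = all_input_ids[prefix_offset:]
    if len(pending) > 3:
        # More than 3 pending ids and still no clean delta: emit whatever
        # decoded and move both offsets past the whole sequence, so the
        # stream cannot stall forever on an invalid byte sequence.
        return new_text, len(all_input_ids), len(all_input_ids)
    # Up to 3 pending ids: keep waiting for the byte sequence to complete.
    return "", prefix_offset, read_offset


print(flush_or_wait("He\ufffd", [7, 8, 9], prefix_offset=0, read_offset=0))      # keeps waiting
print(flush_or_wait("He\ufffd", [7, 8, 9, 10], prefix_offset=0, read_offset=0))  # flushes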
