File tree Expand file tree Collapse file tree 2 files changed +17
-7
lines changed
paddleformers/transformers Expand file tree Collapse file tree 2 files changed +17
-7
lines changed Original file line number Diff line number Diff line change @@ -3496,10 +3496,15 @@ def decode_token(
3496
3496
# from byte fallback tokenization.
3497
3497
# If it's in the middle, it's probably a real invalid id generated
3498
3498
# by the model
3499
- prefix_index = new_text .index (prefix_text )
3500
- new_text = new_text [prefix_index + len (prefix_text ) :]
3501
- return new_text , read_offset , len (all_input_ids )
3499
+ if new_text .startswith (prefix_text ):
3500
+ prefix_index = new_text .index (prefix_text )
3501
+ new_text = new_text [prefix_index + len (prefix_text ) :]
3502
+ return new_text , read_offset , len (all_input_ids )
3503
+ else :
3504
+ return "" , prefix_offset , len (all_input_ids )
3502
3505
else :
3506
+ if len (all_input_ids [prefix_offset :]) > 3 :
3507
+ return new_text , len (all_input_ids ), len (all_input_ids )
3503
3508
return "" , prefix_offset , read_offset
3504
3509
3505
3510
def batch_decode (
Original file line number Diff line number Diff line change @@ -487,15 +487,20 @@ def decode_token(
487
487
all_input_ids [prefix_offset :], skip_special_tokens = skip_special_tokens , clean_up_tokenization_spaces = False
488
488
)
489
489
490
- if len (new_text ) > len (prefix_text ) and not prefix_text . endswith ( "�" ) and not new_text . endswith ( "�" ) :
490
+ if len (new_text ) > len (prefix_text ) and "�" not in prefix_text and "�" not in new_text :
491
491
# utf-8 char at the end means it's a potential unfinished byte sequence
492
492
# from byte fallback tokenization.
493
493
# If it's in the middle, it's probably a real invalid id generated
494
494
# by the model
495
- prefix_index = new_text .index (prefix_text )
496
- new_text = new_text [prefix_index + len (prefix_text ) :]
497
- return new_text , read_offset , len (all_input_ids )
495
+ if new_text .startswith (prefix_text ):
496
+ prefix_index = new_text .index (prefix_text )
497
+ new_text = new_text [prefix_index + len (prefix_text ) :]
498
+ return new_text , read_offset , len (all_input_ids )
499
+ else :
500
+ return "" , prefix_offset , len (all_input_ids )
498
501
else :
502
+ if len (all_input_ids [prefix_offset :]) > 3 :
503
+ return new_text , len (all_input_ids ), len (all_input_ids )
499
504
return "" , prefix_offset , read_offset
500
505
501
506
You can’t perform that action at this time.
0 commit comments