@@ -2598,21 +2598,13 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        // Divide added tokens into those that left/right strip, and those that don't
-        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
-        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
-        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
-            added_tokens_with_strip.slice()
-                // Sort by length (desc) to avoid early partial matches
-                .sort((a, b) => b.content.length - a.content.length)
-                .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
-                .join('|')
-        ) : null;
         this.added_tokens_splitter = new DictionarySplitter(
-            added_tokens_without_strip.map(x => x.content),
-            split_regex,
+            this.added_tokens.map(x => x.content),
         );
 
+        /** @type {Map<string, AddedToken>} */
+        this.added_tokens_map = new Map(this.added_tokens.map(x => [x.content, x]));
+
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
         this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);
@@ -2907,38 +2899,49 @@ export class PreTrainedTokenizer extends Callable {
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
         const sections = this.added_tokens_splitter.split(text);
-        const tokens = sections.map((x, section_index) => {
-            const addedToken = this.added_tokens.find(t => t.content === x);
-            if (addedToken !== undefined) {
-                // Ignore added tokens
-                return x
-            } else {
-                if (this.remove_space === true) {
-                    x = x.trim().split(/\s+/).join(' ');
-                }
-                if (this.do_lowercase_and_remove_accent) {
-                    x = lowercase_and_remove_accent(x);
-                }
 
-                if (this.normalizer !== null) {
-                    x = this.normalizer(x);
+        // Process left/right stripping of added tokens
+        for (let i = 0; i < sections.length; ++i) {
+            const addedToken = this.added_tokens_map.get(sections[i]);
+            if (addedToken) {
+                if (addedToken.lstrip && i > 0) {
+                    sections[i - 1] = sections[i - 1].trimEnd();
                 }
-
-                // If, after normalization, this section is empty (e.g., trimming whitespace),
-                // we return an empty array
-                if (x.length === 0) {
-                    return [];
+                if (addedToken.rstrip && i < sections.length - 1) {
+                    sections[i + 1] = sections[i + 1].trimStart();
                 }
+            }
+        }
 
-                const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
-                    section_index,
-                }) : [x];
+        const tokens = sections.flatMap((x, section_index) => {
+            if (x.length === 0) return [];
+            if (this.added_tokens_map.has(x)) return [x]; // Return added tokens unchanged
 
-                const tokens = this.model(sectionTokens);
+            if (this.remove_space === true) {
+                x = x.trim().split(/\s+/).join(' ');
+            }
+            if (this.do_lowercase_and_remove_accent) {
+                x = lowercase_and_remove_accent(x);
+            }
+
+            if (this.normalizer !== null) {
+                x = this.normalizer(x);
+            }
 
-                return tokens;
+            // If, after normalization, this section is empty (e.g., trimming whitespace),
+            // we return an empty array
+            if (x.length === 0) {
+                return [];
             }
-        }).flat();
+
+            const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
+                section_index,
+            }) : [x];
+
+            const tokens = this.model(sectionTokens);
+
+            return tokens;
+        });
 
         return tokens;
     }
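The second hunk moves whitespace stripping out of the split regex and into an explicit pass over the split sections: whenever a section is an added token with lstrip or rstrip set, whitespace is trimmed from the neighbouring sections before normalization and tokenization. A self-contained sketch of that pass, assuming a pre-split sections array (the real code obtains it from DictionarySplitter.split, which isolates added tokens as their own entries):

// Hypothetical added-token table; in the library this is this.added_tokens_map.
const added_tokens_map = new Map([
    ['<mask>', { content: '<mask>', lstrip: true, rstrip: true }],
]);

// Assumed pre-split input, with the added token as its own section.
const sections = ['Hello ', '<mask>', ' world'];

for (let i = 0; i < sections.length; ++i) {
    const addedToken = added_tokens_map.get(sections[i]);
    if (addedToken) {
        // lstrip consumes the whitespace to the token's left...
        if (addedToken.lstrip && i > 0) {
            sections[i - 1] = sections[i - 1].trimEnd();
        }
        // ...and rstrip consumes the whitespace to its right.
        if (addedToken.rstrip && i < sections.length - 1) {
            sections[i + 1] = sections[i + 1].trimStart();
        }
    }
}

console.log(sections); // ['Hello', '<mask>', 'world']

Sections emptied by this trimming are then dropped by the `if (x.length === 0) return [];` check at the top of the flatMap callback.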