-use fancy_regex::Regex;
+use fancy_regex::Regex as FancyRegex;
+use regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
 use thread_local::ThreadLocal;
@@ -133,9 +134,9 @@ pub struct CoreBPE {
     decoder: HashMap<Rank, &'static [u8]>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex: Regex,
-    special_regex: Regex,
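+    // Special-token matching stays on fancy_regex; the split pattern above now
+    // uses the plain regex crate.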
+    special_regex: FancyRegex,
     regex_tls: ThreadLocal<Regex>,
-    special_regex_tls: ThreadLocal<Regex>,
+    special_regex_tls: ThreadLocal<FancyRegex>,
     sorted_token_bytes: Vec<&'static [u8]>,
 }
 
@@ -144,7 +145,7 @@ impl CoreBPE {
         self.regex_tls.get_or(|| self.regex.clone())
     }
 
-    fn _get_tl_special_regex(&self) -> &Regex {
+    fn _get_tl_special_regex(&self) -> &FancyRegex {
         self.special_regex_tls.get_or(|| self.special_regex.clone())
     }
 
@@ -161,24 +162,85 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
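+    /// Appends the tokens for `text` to `ret` and returns how many tokens the
+    /// final piece produced, which `_encode_native` needs for unstable-token
+    /// handling.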
+    fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
-        let mut ret = vec![];
+        let mut last_end = 0;
+        let mut last_piece_token_len = 0;
+        let mut piece: &[u8] = &[];
         for mat in regex.find_iter(text) {
-            let piece = mat.unwrap().as_str().as_bytes();
+            piece = mat.as_str().as_bytes();
+            let start = mat.start();
+            let end = mat.end();
+
+            // If there is a whitespace gap between this piece and the previous one, add its tokens
+            if last_end < start {
+                // If the current piece starts with whitespace, the whole gap is one new piece
+                if mat
+                    .as_str()
+                    .chars()
+                    .next()
+                    .map_or(false, |c| c.is_whitespace())
+                {
+                    let wpiece = text[last_end..start].as_bytes();
+                    match self.encoder.get(wpiece) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
+                    }
+                // otherwise the last char of the gap makes one piece, and the rest (if any) makes another
+                } else {
+                    let last_char_size = text[last_end..start]
+                        .chars()
+                        .next_back()
+                        .unwrap()
+                        .len_utf8();
+                    // Example for gpt-4o: in the text "= 6", "=" and "6" are matches and " " is
+                    // the gap, so the gap makes just one piece
+                    if last_char_size < start - last_end {
+                        let wpiece1 = text[last_end..start - last_char_size].as_bytes();
+                        match self.encoder.get(wpiece1) {
+                            Some(token) => ret.push(*token),
+                            None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
+                        }
+                    }
+                    let wpiece2 = text[start - last_char_size..start].as_bytes();
+                    match self.encoder.get(wpiece2) {
+                        Some(token) => ret.push(*token),
+                        None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
+                    }
+                }
+            }
+            last_end = end;
+
+            // Now add piece tokens
             match self.encoder.get(piece) {
                 Some(token) => ret.push(*token),
                 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
             }
         }
-        ret
+        // Handle a gap of whitespace at the end of the text
+        if last_end < text.len() {
+            piece = text[last_end..].as_bytes();
+            match self.encoder.get(piece) {
+                Some(token) => ret.push(*token),
+                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
+            }
+        }
+
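+        // Recompute how many tokens the final piece produced, so the caller can
+        // trim unstable tokens at the end if it needs to.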
+        if !piece.is_empty() {
+            last_piece_token_len = match self.encoder.get(piece) {
+                Some(_) => 1,
+                None => byte_pair_encode(piece, &self.encoder).len(),
+            };
+        }
+
+        last_piece_token_len
     }
 
     fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
-        let regex = self._get_tl_regex();
+
         let mut ret = vec![];
 
         let mut start = 0;
@@ -201,17 +263,10 @@ impl CoreBPE {
             }
             let end = next_special.map_or(text.len(), |m| m.start());
 
-            // Okay, here we go, compare this logic to _encode_ordinary_native
-            for mat in regex.find_iter(&text[start..end]) {
-                let piece = mat.unwrap().as_str().as_bytes();
-                if let Some(token) = self.encoder.get(piece) {
-                    last_piece_token_len = 1;
-                    ret.push(*token);
-                    continue;
-                }
-                let tokens = byte_pair_encode(piece, &self.encoder);
-                last_piece_token_len = tokens.len();
-                ret.extend(&tokens);
+            if end > start {
+                // The regex is fetched inside _encode_ordinary_native_impl rather
+                // than passed in from here; that seems harmless.
+                last_piece_token_len =
+                    self._encode_ordinary_native_impl(&text[start..end], &mut ret);
             }
 
             match next_special {
@@ -271,6 +326,13 @@ impl CoreBPE {
         (tokens, last_piece_token_len)
     }
 
+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
+        // This wrapper serves callers that do not have a `ret` buffer to pass in.
+        let mut ret = vec![];
+        self._encode_ordinary_native_impl(text, &mut ret);
+        ret
+    }
+
     fn _encode_unstable_native(
         &self,
         text: &str,
@@ -302,7 +364,7 @@ impl CoreBPE {
         // Separating this from the loop below helps with performance in a common case.
         let mut point = self
             .sorted_token_bytes
-            .partition_point(|x| *x < unstable_bytes.as_slice());
+            .partition_point(|x| &x[..] < unstable_bytes.as_slice());
         while point < self.sorted_token_bytes.len()
             && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
         {
@@ -318,9 +380,7 @@ impl CoreBPE {
         for i in 1..unstable_bytes.len() {
             let prefix = &unstable_bytes[..i];
             let suffix = &unstable_bytes[i..];
-            let mut point = self
-                .sorted_token_bytes
-                .partition_point(|x| *x < suffix);
+            let mut point = self.sorted_token_bytes.partition_point(|x| &x[..] < suffix);
             // TODO: Perf optimisation if suffix starts with " "?
             while point < self.sorted_token_bytes.len()
                 && self.sorted_token_bytes[point].starts_with(suffix)
@@ -393,15 +453,15 @@ impl CoreBPE {
         encoder: HashMap<Vec<u8>, Rank>,
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
-    ) -> Result<Self, fancy_regex::Error> {
+    ) -> Result<Self, regex::Error> {
         let regex = Regex::new(pattern)?;
 
         let special_regex = {
             let parts = special_tokens_encoder
                 .keys()
                 .map(|s| fancy_regex::escape(s))
                 .collect::<Vec<_>>();
-            Regex::new(&parts.join("|"))?
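+            // parts contains only escaped literals, so this pattern should
+            // always compile; hence the unwrap.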
+            FancyRegex::new(&parts.join("|")).unwrap()
         };
 
     // Use unsafe to extend the lifetime of references to the encoder's keys