22#![ allow( clippy:: borrow_deref_ref) ]
33
44use std:: collections:: HashSet ;
5+ use std:: num:: NonZeroU64 ;
56use std:: thread;
67
78use fancy_regex:: Regex ;
89use pyo3:: exceptions;
910use pyo3:: prelude:: * ;
10- use pyo3:: types:: { PyBytes , PyList , PyTuple } ;
1111use pyo3:: PyResult ;
12+ use pyo3:: types:: { PyBytes , PyList , PyTuple } ;
1213use rustc_hash:: FxHashMap as HashMap ;
1314
// NOTE(review): this line is a `+` (added) line of a diff hunk. `Rank` is the
// token-id type introduced by this change; the paired -/+ lines throughout the
// diff show it replacing the previous `usize` keys and values (e.g. the
// `HashMap<Vec<u8>, usize>` -> `HashMap<Vec<u8>, Rank>` signature changes below).
15+ type Rank = u32 ;
16+
1417fn _byte_pair_merge < T > (
1518 piece : & [ u8 ] ,
16- ranks : & HashMap < Vec < u8 > , usize > ,
19+ ranks : & HashMap < Vec < u8 > , Rank > ,
1720 f : impl Fn ( std:: ops:: Range < usize > ) -> T ,
1821) -> Vec < T > {
1922 // This is a vector of (start, rank).
2023 // The rank is of the byte pair starting at position start.
2124 // The rank of the last item in the vector is not a valid value.
22- let mut parts: Vec < ( usize , usize ) > = ( 0 ..piece. len ( ) + 1 ) . map ( |i| ( i, usize :: MAX ) ) . collect ( ) ;
25+ let mut parts: Vec < ( usize , Rank ) > = ( 0 ..piece. len ( ) + 1 ) . map ( |i| ( i, Rank :: MAX ) ) . collect ( ) ;
2326
2427 let get_rank = {
2528 #[ inline( always) ]
26- |parts : & Vec < ( usize , usize ) > , start_idx : usize , skip : usize | {
29+ |parts : & Vec < ( usize , Rank ) > , start_idx : usize , skip : usize | {
2730 if ( start_idx + skip + 2 ) < parts. len ( ) {
2831 ranks
2932 . get ( & piece[ parts[ start_idx] . 0 ..parts[ start_idx + skip + 2 ] . 0 ] )
@@ -39,8 +42,8 @@ fn _byte_pair_merge<T>(
3942 for i in 0 ..parts. len ( ) - 2 {
4043 match get_rank ( & parts, i, 0 ) {
4144 Some ( rank) => {
42- // usize ::MAX is a sentinel value and cannot be a valid rank
43- debug_assert ! ( rank != usize :: MAX ) ;
45+ // Rank ::MAX is a sentinel value and cannot be a valid rank
46+ debug_assert ! ( rank != Rank :: MAX ) ;
4447 parts[ i] . 1 = rank;
4548 }
4649 None => {
@@ -63,26 +66,26 @@ fn _byte_pair_merge<T>(
6366 break ;
6467 }
6568
66- // usize ::MAX is a sentinel rank value allowing us to
69+ // Rank ::MAX is a sentinel rank value allowing us to
6770 // take the min more quickly
68- let mut min_rank: ( usize , usize ) = ( usize :: MAX , 0 ) ;
71+ let mut min_rank: ( Rank , usize ) = ( Rank :: MAX , 0 ) ;
6972 for ( i, & ( _, rank) ) in parts[ ..parts. len ( ) - 1 ] . iter ( ) . enumerate ( ) {
7073 if rank < min_rank. 0 {
7174 min_rank = ( rank, i) ;
7275 }
7376 }
7477
75- if min_rank. 0 != usize :: MAX {
78+ if min_rank. 0 != Rank :: MAX {
7679 let i = min_rank. 1 ;
7780
7881 // NOTE: We are about to remove parts[i + 1]. We do not do it
7982 // yet because there are cache-locality benefits to updating
8083 // parts[i] and parts[i-1] before removing, which could thrash
8184 // the cache. Thus, we update the rank calculation by skipping over
8285 // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
83- parts[ i] . 1 = get_rank ( & parts, i, 1 ) . unwrap_or ( usize :: MAX ) ;
86+ parts[ i] . 1 = get_rank ( & parts, i, 1 ) . unwrap_or ( Rank :: MAX ) ;
8487 if i > 0 {
85- parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 , 1 ) . unwrap_or ( usize :: MAX ) ;
88+ parts[ i - 1 ] . 1 = get_rank ( & parts, i - 1 , 1 ) . unwrap_or ( Rank :: MAX ) ;
8689 }
8790
8891 parts. remove ( i + 1 ) ;
@@ -97,14 +100,14 @@ fn _byte_pair_merge<T>(
97100 out
98101}
99102
// NOTE(review): diff hunk — the `-` line is the old `usize`-typed signature, the
// `+` line the new `Rank`-typed one; the body below is unchanged context.
100- pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , usize > ) -> Vec < usize > {
// Encode `piece` into BPE token ranks: a single byte is looked up directly in
// `ranks`; longer pieces go through `_byte_pair_merge`, mapping each merged
// byte range back to its rank.
103+ pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
101104 if piece. len ( ) == 1 {
// Map indexing panics if the byte is absent — assumes `ranks` contains every
// single byte (a complete byte-level base vocabulary); TODO confirm at callers.
102105 return vec ! [ ranks[ piece] ] ;
103106 }
104107 _byte_pair_merge ( piece, ranks, |p| ranks[ & piece[ p. start ..p. end ] ] )
105108 }
106109
107- pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , usize > ) -> Vec < & ' a [ u8 ] > {
110+ pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
108111 if piece. len ( ) == 1 {
109112 return vec ! [ piece] ;
110113 }
@@ -152,7 +155,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
152155// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
153156// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
154157
155- use std:: num:: NonZeroU64 ;
// Newtype wrapper over `NonZeroU64`, used by `hash_current_thread` (body elided
// from this hunk). NOTE(review): presumably mirrors the private layout of
// `std::thread::ThreadId` so the numeric id can be read out — the code that
// would confirm this is not visible here; verify against the full file.
156158 pub struct FakeThreadId ( NonZeroU64 ) ;
157159
158160fn hash_current_thread ( ) -> usize {
@@ -169,12 +171,13 @@ fn hash_current_thread() -> usize {
169171}
170172
// Size of the per-thread regex caches (`regex_tls` / `special_regex_tls` in
// `CoreBPE`): a thread's slot is chosen via `hash_current_thread() % MAX_NUM_THREADS`,
// so distinct threads may share a slot — safe only because entries are read-only.
171173 const MAX_NUM_THREADS : usize = 128 ;
174+
172175#[ pyclass]
173176struct CoreBPE {
174- encoder : HashMap < Vec < u8 > , usize > ,
175- special_tokens_encoder : HashMap < String , usize > ,
176- decoder : HashMap < usize , Vec < u8 > > ,
177- special_tokens_decoder : HashMap < usize , Vec < u8 > > ,
177+ encoder : HashMap < Vec < u8 > , Rank > ,
178+ special_tokens_encoder : HashMap < String , Rank > ,
179+ decoder : HashMap < Rank , Vec < u8 > > ,
180+ special_tokens_decoder : HashMap < Rank , Vec < u8 > > ,
178181 regex_tls : Vec < Regex > ,
179182 special_regex_tls : Vec < Regex > ,
180183 sorted_token_bytes : Vec < Vec < u8 > > ,
@@ -192,7 +195,7 @@ impl CoreBPE {
192195 & self . special_regex_tls [ hash_current_thread ( ) % MAX_NUM_THREADS ]
193196 }
194197
195- fn _decode_native ( & self , tokens : & [ usize ] ) -> Vec < u8 > {
198+ fn _decode_native ( & self , tokens : & [ Rank ] ) -> Vec < u8 > {
196199 let mut ret = Vec :: with_capacity ( tokens. len ( ) * 2 ) ;
197200 for token in tokens {
198201 let token_bytes = self
@@ -204,7 +207,7 @@ impl CoreBPE {
204207 ret
205208 }
206209
207- fn _encode_ordinary_native ( & self , text : & str ) -> Vec < usize > {
210+ fn _encode_ordinary_native ( & self , text : & str ) -> Vec < Rank > {
208211 // This is the core of the encoding logic; the other functions in here
209212 // just make things complicated :-)
210213 let regex = self . _get_tl_regex ( ) ;
@@ -220,7 +223,7 @@ impl CoreBPE {
220223 ret
221224 }
222225
223- fn _encode_native ( & self , text : & str , allowed_special : & HashSet < & str > ) -> ( Vec < usize > , usize ) {
226+ fn _encode_native ( & self , text : & str , allowed_special : & HashSet < & str > ) -> ( Vec < Rank > , usize ) {
224227 let special_regex = self . _get_tl_special_regex ( ) ;
225228 let regex = self . _get_tl_regex ( ) ;
226229 let mut ret = vec ! [ ] ;
@@ -278,9 +281,9 @@ impl CoreBPE {
278281
279282 fn _increase_last_piece_token_len (
280283 & self ,
281- tokens : Vec < usize > ,
284+ tokens : Vec < Rank > ,
282285 mut last_piece_token_len : usize ,
283- ) -> ( Vec < usize > , usize ) {
286+ ) -> ( Vec < Rank > , usize ) {
284287 // Unfortunately, the locations where our regex splits can be unstable.
285288 // For the purposes of determining unstable tokens, unstable regex splitting
286289 // is only a problem if a split that was present disappears, since this can
@@ -319,7 +322,7 @@ impl CoreBPE {
319322 & self ,
320323 text : & str ,
321324 allowed_special : & HashSet < & str > ,
322- ) -> ( Vec < usize > , HashSet < Vec < usize > > ) {
325+ ) -> ( Vec < Rank > , HashSet < Vec < Rank > > ) {
323326 let ( tokens, last_piece_token_len) = self . _encode_native ( text, allowed_special) ;
324327 if last_piece_token_len == 0 {
325328 // If last_piece_token_len is zero, the last token was a special token and we have
@@ -436,8 +439,8 @@ impl CoreBPE {
436439impl CoreBPE {
437440 #[ new]
438441 fn new (
439- encoder : HashMap < Vec < u8 > , usize > ,
440- special_tokens_encoder : HashMap < String , usize > ,
442+ encoder : HashMap < Vec < u8 > , Rank > ,
443+ special_tokens_encoder : HashMap < String , Rank > ,
441444 pattern : & str ,
442445 ) -> PyResult < Self > {
443446 let regex = Regex :: new ( pattern)
@@ -452,15 +455,15 @@ impl CoreBPE {
452455 . map_err ( |e| PyErr :: new :: < exceptions:: PyValueError , _ > ( e. to_string ( ) ) ) ?
453456 } ;
454457
455- let decoder: HashMap < usize , Vec < u8 > > =
458+ let decoder: HashMap < Rank , Vec < u8 > > =
456459 encoder. iter ( ) . map ( |( k, v) | ( * v, k. clone ( ) ) ) . collect ( ) ;
457460
458461 assert ! (
459462 encoder. len( ) == decoder. len( ) ,
460463 "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
461464 ) ;
462465
463- let special_tokens_decoder: HashMap < usize , Vec < u8 > > = special_tokens_encoder
466+ let special_tokens_decoder: HashMap < Rank , Vec < u8 > > = special_tokens_encoder
464467 . iter ( )
465468 . map ( |( k, v) | ( * v, k. as_bytes ( ) . to_vec ( ) ) )
466469 . collect ( ) ;
@@ -486,15 +489,15 @@ impl CoreBPE {
486489 // Encoding
487490 // ====================
488491
489- fn encode_ordinary ( & self , py : Python , text : & str ) -> Vec < usize > {
492+ fn encode_ordinary ( & self , py : Python , text : & str ) -> Vec < Rank > {
490493 py. allow_threads ( || self . _encode_ordinary_native ( text) )
491494 }
492495
493- fn encode ( & self , py : Python , text : & str , allowed_special : HashSet < & str > ) -> Vec < usize > {
496+ fn encode ( & self , py : Python , text : & str , allowed_special : HashSet < & str > ) -> Vec < Rank > {
494497 py. allow_threads ( || self . _encode_native ( text, & allowed_special) . 0 )
495498 }
496499
497- fn _encode_bytes ( & self , py : Python , bytes : & [ u8 ] ) -> Vec < usize > {
500+ fn _encode_bytes ( & self , py : Python , bytes : & [ u8 ] ) -> Vec < Rank > {
498501 py. allow_threads ( || {
499502 match std:: str:: from_utf8 ( bytes) {
500503 Ok ( text) => self . _encode_ordinary_native ( text) ,
@@ -534,7 +537,7 @@ impl CoreBPE {
534537 ( tokens, py_completions) . into_py ( py)
535538 }
536539
537- fn encode_single_token ( & self , piece : & [ u8 ] ) -> PyResult < usize > {
540+ fn encode_single_token ( & self , piece : & [ u8 ] ) -> PyResult < Rank > {
538541 if let Some ( token) = self . encoder . get ( piece) . copied ( ) {
539542 return Ok ( token) ;
540543 }
@@ -546,7 +549,7 @@ impl CoreBPE {
546549 Err ( PyErr :: new :: < exceptions:: PyKeyError , _ > ( piece. to_owned ( ) ) )
547550 }
548551
549- fn encode_single_piece ( & self , piece : & [ u8 ] ) -> Vec < usize > {
552+ fn encode_single_piece ( & self , piece : & [ u8 ] ) -> Vec < Rank > {
550553 if let Some ( token) = self . encoder . get ( piece) {
551554 return vec ! [ * token] ;
552555 }
@@ -557,12 +560,12 @@ impl CoreBPE {
557560 // Decoding
558561 // ====================
559562
560- fn decode_bytes ( & self , py : Python , tokens : Vec < usize > ) -> Py < PyBytes > {
563+ fn decode_bytes ( & self , py : Python , tokens : Vec < Rank > ) -> Py < PyBytes > {
561564 let bytes = py. allow_threads ( || self . _decode_native ( & tokens) ) ;
562565 PyBytes :: new ( py, & bytes) . into ( )
563566 }
564567
565- fn decode_single_token_bytes ( & self , py : Python , token : usize ) -> PyResult < Py < PyBytes > > {
568+ fn decode_single_token_bytes ( & self , py : Python , token : Rank ) -> PyResult < Py < PyBytes > > {
566569 if let Some ( bytes) = self . decoder . get ( & token) {
567570 return Ok ( PyBytes :: new ( py, bytes) . into ( ) ) ;
568571 }
0 commit comments