@@ -209,11 +209,10 @@ impl CoreBPE {
209
209
let mut ret = vec ! [ ] ;
210
210
for mat in regex. find_iter ( text) {
211
211
let piece = mat. unwrap ( ) . as_str ( ) . as_bytes ( ) ;
212
- if let Some ( token ) = self . encoder . get ( piece) {
213
- ret. push ( * token) ;
214
- continue ;
212
+ match self . encoder . get ( piece) {
213
+ Some ( token ) => ret. push ( * token) ,
214
+ None => ret . extend ( & byte_pair_encode ( piece , & self . encoder ) ) ,
215
215
}
216
- ret. extend ( & byte_pair_encode ( piece, & self . encoder ) ) ;
217
216
}
218
217
ret
219
218
}
@@ -525,7 +524,10 @@ impl CoreBPE {
525
524
unstable_bytes. extend_from_slice ( & bytes[ e. valid_up_to ( ) ..] ) ;
526
525
527
526
tokens. truncate ( tokens. len ( ) - last_piece_token_len) ;
528
- tokens. extend ( byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ;
527
+ match self . encoder . get ( & unstable_bytes) {
528
+ Some ( token) => tokens. push ( * token) ,
529
+ None => tokens. extend ( & byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ,
530
+ }
529
531
}
530
532
tokens
531
533
}
@@ -590,15 +592,26 @@ impl CoreBPE {
590
592
mod tests {
591
593
use rustc_hash:: FxHashMap as HashMap ;
592
594
593
- use crate :: corebpe:: byte_pair_split;
595
+ use crate :: corebpe:: { byte_pair_split, Rank } ;
594
596
595
- #[ test]
596
- fn very_simple_test ( ) {
597
- let mut ranks = HashMap :: default ( ) ;
598
- ranks. insert ( b"ab" . to_vec ( ) , 1 ) ;
599
- ranks. insert ( b"cd" . to_vec ( ) , 2 ) ;
597
+ fn setup_ranks ( ) -> HashMap < Vec < u8 > , Rank > {
598
+ HashMap :: from_iter ( [
599
+ ( b"ab" . to_vec ( ) , 0 ) ,
600
+ ( b"cd" . to_vec ( ) , 1 ) ,
601
+ ] )
602
+ }
600
603
604
/// Splitting `b"abcd"` with the shared rank table is expected to yield
/// the two pieces `b"ab"` and `b"cd"`, in that order.
#[test]
fn test_simple_characters() {
    let rank_table = setup_ranks();
    let pieces = byte_pair_split(b"abcd", &rank_table);
    assert_eq!(pieces, vec![b"ab", b"cd"]);
}
610
/// Splitting `b"abab"` with the shared rank table is expected to yield
/// the piece `b"ab"` twice.
#[test]
fn test_repeated_characters() {
    let rank_table = setup_ranks();
    let pieces = byte_pair_split(b"abab", &rank_table);
    assert_eq!(pieces, vec![b"ab", b"ab"]);
}
616
+
604
617
}
0 commit comments