Skip to content

Commit 48e7eda

Browse files
Lőrinc authored and tmm1 committed
Avoid calling byte_pair_encode for existing tokens
This way byte_pair_encode can be optimized further, assuming we'll always have at least 2 tokens (cherry-picked from openai/tiktoken@b4c687e)
1 parent e3c7845 commit 48e7eda

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/corebpe.rs

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -209,11 +209,10 @@ impl CoreBPE {
209209
let mut ret = vec![];
210210
for mat in regex.find_iter(text) {
211211
let piece = mat.unwrap().as_str().as_bytes();
212-
if let Some(token) = self.encoder.get(piece) {
213-
ret.push(*token);
214-
continue;
212+
match self.encoder.get(piece) {
213+
Some(token) => ret.push(*token),
214+
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
215215
}
216-
ret.extend(&byte_pair_encode(piece, &self.encoder));
217216
}
218217
ret
219218
}
@@ -525,7 +524,10 @@ impl CoreBPE {
525524
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
526525

527526
tokens.truncate(tokens.len() - last_piece_token_len);
528-
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
527+
match self.encoder.get(&unstable_bytes) {
528+
Some(token) => tokens.push(*token),
529+
None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)),
530+
}
529531
}
530532
tokens
531533
}
@@ -590,15 +592,26 @@ impl CoreBPE {
590592
mod tests {
591593
use rustc_hash::FxHashMap as HashMap;
592594

593-
use crate::corebpe::byte_pair_split;
595+
use crate::corebpe::{byte_pair_split, Rank};
594596

595-
#[test]
596-
fn very_simple_test() {
597-
let mut ranks = HashMap::default();
598-
ranks.insert(b"ab".to_vec(), 1);
599-
ranks.insert(b"cd".to_vec(), 2);
597+
fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
598+
HashMap::from_iter([
599+
(b"ab".to_vec(), 0),
600+
(b"cd".to_vec(), 1),
601+
])
602+
}
600603

604+
#[test]
605+
fn test_simple_characters() {
606+
let ranks = setup_ranks();
601607
let res = byte_pair_split(b"abcd", &ranks);
602608
assert_eq!(res, vec![b"ab", b"cd"]);
603609
}
610+
#[test]
611+
fn test_repeated_characters() {
612+
let ranks = setup_ranks();
613+
let res = byte_pair_split(b"abab", &ranks);
614+
assert_eq!(res, vec![b"ab", b"ab"]);
615+
}
616+
604617
}

0 commit comments

Comments
 (0)