Skip to content

Commit 300ec0a

Browse files
authored
embed precomputed odht tables for prefixes_of_mergeable_ranks (#13)
This provides a ~7% increase in count_tokens performance on Intel processors. The odht table lives in shared memory inside the mmap'd .so file, so memory usage is improved as well. It also speeds up runtime boot, since the table is now pregenerated.
1 parent 615bf1f commit 300ec0a

File tree

5 files changed

+159
-11
lines changed

5 files changed

+159
-11
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ name = "tiktoken"
33
version = "0.1.0"
44
edition = "2021"
55
rust-version = "1.57.0"
6+
build = "build.rs"
67

78
[lib]
89
name = "tiktoken"
@@ -13,6 +14,14 @@ tiktoken-rs = "0.5.4"
1314
memory-stats = "1.2.0"
1415
test-case = "2.0.0-rc3"
1516

17+
[build-dependencies]
18+
hex = "0.4.3"
19+
sha2 = "0.10.6"
20+
base64 = "0.21.0"
21+
rustc-hash = "1.1.0"
22+
odht = "0.3.1"
23+
const-primes = "0.8.7"
24+
1625
[dependencies]
1726
fancy-regex = "0.10.0"
1827
regex = "1.7.0"
@@ -23,6 +32,7 @@ sha2 = "0.10.6"
2332
base64 = "0.21.0"
2433
thiserror = "1.0.38"
2534
const-primes = "0.8.7"
35+
odht = "0.3.1"
2636

2737
[[bench]]
2838
name = "bench"

build.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
include!("src/load.rs");
2+
include!("src/rollhash.rs");
3+
include!("src/odht.rs");
4+
5+
use std::env;
6+
use std::fs::File;
7+
use std::io::Write;
8+
use std::path::Path;
9+
use odht::HashTableOwned;
10+
use rustc_hash::FxHashSet as HashSet;
11+
12+
fn main() {
13+
println!("cargo:rerun-if-changed=build.rs");
14+
println!("cargo:rerun-if-changed=src/load.rs");
15+
16+
let out_dir = env::var("OUT_DIR").unwrap();
17+
let mut file = File::create(&Path::new(&out_dir).join("static.rs")).unwrap();
18+
writeln!(file, "pub mod data {{").unwrap();
19+
20+
generate("r50k_base",
21+
&mut file,
22+
&load_tiktoken_bpe(
23+
include_bytes!("data/r50k_base.tiktoken"),
24+
"306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
25+
).unwrap());
26+
27+
generate("p50k_base",
28+
&mut file,
29+
&load_tiktoken_bpe(
30+
include_bytes!("data/p50k_base.tiktoken"),
31+
"94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
32+
).unwrap());
33+
34+
generate("cl100k_base",
35+
&mut file,
36+
&load_tiktoken_bpe(
37+
include_bytes!("data/cl100k_base.tiktoken"),
38+
"223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
39+
).unwrap());
40+
41+
generate("o200k_base",
42+
&mut file,
43+
&load_tiktoken_bpe(
44+
include_bytes!("data/o200k_base.tiktoken"),
45+
"446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
46+
).unwrap());
47+
48+
generate("codestral",
49+
&mut file,
50+
&load_tiktoken_bpe(
51+
include_bytes!("data/codestral.tiktoken"),
52+
"bd5e66af07259851e88c3e483f88371dc2408cb0ce8b9787d29eaecdbb78eade",
53+
).unwrap());
54+
55+
generate("llama3",
56+
&mut file,
57+
&load_tiktoken_bpe(
58+
include_bytes!("data/llama3.tiktoken"),
59+
"82e9d31979e92ab929cd544440f129d9ecd797b69e327f80f17e1c50d5551b55",
60+
).unwrap());
61+
62+
generate("deepseekv2",
63+
&mut file,
64+
&load_tiktoken_bpe(
65+
include_bytes!("data/deepseekv2.tiktoken"),
66+
"3516b4e6e24389f7d1b288d861ce063da13296f916d29384e56ea9e0f6ba6674",
67+
).unwrap());
68+
69+
writeln!(file, "}}").unwrap();
70+
}
71+
72+
fn generate(
73+
name: &str,
74+
file: &mut File,
75+
mergeable_ranks: &HashMap<Vec<u8>, usize>,
76+
) {
77+
writeln!(
78+
file,
79+
" pub const {}_PREFIXES_ODHT: &'static [u8] = include_bytes!(\"{}.prefixes.odht\");",
80+
name.to_uppercase(),
81+
name
82+
).unwrap();
83+
84+
let mut prefixes_of_mergeable_ranks = mergeable_ranks
85+
.keys()
86+
.flat_map(|bytes| {
87+
(1..=bytes.len())
88+
.map(|i| roll_hash_slice(&bytes[..i]))
89+
.collect::<Vec<_>>()
90+
})
91+
.collect::<HashSet<_>>();
92+
prefixes_of_mergeable_ranks.insert(0);
93+
prefixes_of_mergeable_ranks.shrink_to_fit();
94+
95+
let mut odht = HashTableOwned::<PrefixConfig>::with_capacity(prefixes_of_mergeable_ranks.len(), 50);
96+
for prefix in prefixes_of_mergeable_ranks {
97+
odht.insert(&prefix, &());
98+
}
99+
100+
let mut file = File::create(format!("{}/{}.prefixes.odht", env::var("OUT_DIR").unwrap(), name)).unwrap();
101+
file.write_all(odht.raw_bytes()).unwrap();
102+
}

src/encoding.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ use rustc_hash::FxHashMap as HashMap;
44
use rustc_hash::FxHashSet as HashSet;
55
use std::sync::Arc;
66
use thiserror::Error;
7+
use odht::HashTableOwned;
78
use crate::rollhash::{roll_hash, roll_hash_slice};
89

10+
include!("odht.rs");
11+
include!(concat!(env!("OUT_DIR"), "/static.rs"));
12+
913
/// A struct that represents an encoding scheme based on byte-pair encoding (BPE).
1014
#[derive(Debug)]
1115
pub struct Encoding {
@@ -16,7 +20,7 @@ pub struct Encoding {
1620
/// The maximum length of the keys in `mergeable_ranks`.
1721
mergeable_ranks_max_key_len: usize,
1822
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
19-
prefixes_of_mergeable_ranks: HashSet<i64>,
23+
prefixes_of_mergeable_ranks: HashTableOwned<PrefixConfig>,
2024
/// The map from special token strings to their values.
2125
special_tokens: HashMap<String, usize>,
2226
/// The maximum token value in the encoding.
@@ -97,16 +101,18 @@ impl Encoding {
97101
)
98102
.map_err(|e| EncodingError::GenericEncodingError(format!("Error creating core BPE: {}", e)))?;
99103

100-
let mut prefixes_of_mergeable_ranks = mergeable_ranks
101-
.keys()
102-
.flat_map(|bytes| {
103-
(1..=bytes.len())
104-
.map(|i| roll_hash_slice(&bytes[..i]))
105-
.collect::<Vec<_>>()
104+
let prefixes_of_mergeable_ranks = unsafe {
105+
HashTableOwned::<PrefixConfig>::from_raw_bytes_unchecked(match name {
106+
"r50k_base" => data::R50K_BASE_PREFIXES_ODHT,
107+
"p50k_base" => data::P50K_BASE_PREFIXES_ODHT,
108+
"cl100k_base" => data::CL100K_BASE_PREFIXES_ODHT,
109+
"o200k_base" => data::O200K_BASE_PREFIXES_ODHT,
110+
"codestral" => data::CODESTRAL_PREFIXES_ODHT,
111+
"llama3" => data::LLAMA3_PREFIXES_ODHT,
112+
"deepseekv2" => data::DEEPSEEKV2_PREFIXES_ODHT,
113+
_ => return Err(EncodingError::GenericEncodingError(format!("Embedded prefix table not found for encoding: {}", name))),
106114
})
107-
.collect::<HashSet<_>>();
108-
prefixes_of_mergeable_ranks.insert(0);
109-
prefixes_of_mergeable_ranks.shrink_to_fit();
115+
};
110116

111117
Ok(Self {
112118
name: name.to_string(),
@@ -148,7 +154,7 @@ impl Encoding {
148154
// or if the current token is not in the prefixes of mergeable ranks,
149155
// we need to split the current token and begin actually checking for the largest
150156
// mergeable prefix
151-
while !self.prefixes_of_mergeable_ranks.contains(&current_token_hash)
157+
while !self.prefixes_of_mergeable_ranks.contains_key(&current_token_hash)
152158
|| current_token.len() > self.mergeable_ranks_max_key_len
153159
{
154160
if current_token.len() > 1 {

src/odht.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use odht::{Config, FxHashFn};
2+
3+
pub struct PrefixConfig;
4+
5+
impl Config for PrefixConfig {
6+
type Key = i64;
7+
type Value = ();
8+
type EncodedKey = [u8; 8];
9+
type EncodedValue = [u8; 0];
10+
type H = FxHashFn;
11+
12+
#[inline(always)]
13+
fn encode_key(k: &Self::Key) -> Self::EncodedKey { k.to_le_bytes() }
14+
#[inline(always)]
15+
fn encode_value(_: &Self::Value) -> Self::EncodedValue { [] }
16+
#[inline(always)]
17+
fn decode_key(k: &Self::EncodedKey) -> Self::Key { i64::from_le_bytes(*k) }
18+
#[inline(always)]
19+
fn decode_value(_: &Self::EncodedValue) -> Self::Value { () }
20+
}

0 commit comments

Comments
 (0)