
Commit 7ba8ad9

use real thread-local storage (#20)
Upstream code in openai/tiktoken is wrapped with PyO3, so they're concerned about short-lived Python-land threads eating up memory. In our case we have a fixed actor thread pool dedicated to tokenization, so we don't need the extra copies, and hash collisions no longer get in the way.
1 parent 2a6523f commit 7ba8ad9
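
For context on the trade-off described above, here is a minimal standalone sketch (not code from this commit) of how the thread_local crate behaves: each thread that calls get_or lazily creates its own value, and every value lives until the ThreadLocal itself is dropped. With many short-lived Python threads those copies can pile up; with a fixed actor thread pool the number of copies is capped at the pool size.

// Standalone sketch of thread_local::ThreadLocal (not from this repo).
use std::sync::Arc;
use std::thread;
use thread_local::ThreadLocal;

fn main() {
    // One lazily-created value per thread; all values live until `tls` is dropped.
    let tls: Arc<ThreadLocal<String>> = Arc::new(ThreadLocal::new());

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let tls = Arc::clone(&tls);
            thread::spawn(move || {
                // The closure runs only on this thread's first call to get_or;
                // later calls on the same thread reuse the stored value.
                let mine = tls.get_or(|| format!("per-thread state #{i}"));
                println!("{mine}");
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
}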

File tree

3 files changed, +23 -37 lines changed


Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ base64 = "0.21.0"
 thiserror = "1.0.38"
 const-primes = "0.8.7"
 odht = "0.3.1"
+thread_local = "1.1.8"
 
 [[bench]]
 name = "bench"

src/corebpe.rs

Lines changed: 11 additions & 37 deletions
@@ -1,10 +1,8 @@
-use std::num::NonZeroU64;
-use std::thread;
-
 use fancy_regex::Regex;
 use rustc_hash::FxHashMap as HashMap;
 use rustc_hash::FxHashSet as HashSet;
 use std::sync::Arc;
+use thread_local::ThreadLocal;
 
 pub type Rank = u32;
 

@@ -129,44 +127,26 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-pub struct FakeThreadId(NonZeroU64);
-
-fn hash_current_thread() -> usize {
-    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
-    // that works great for our use case of avoiding collisions in our array. Unfortunately,
-    // it's private. However, there are only so many ways you can layout a u64, so just transmute
-    // https://github.com/rust-lang/rust/issues/67939
-    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
-    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
-    let x = unsafe {
-        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
-    };
-    u64::from(x) as usize
-}
-
-const MAX_NUM_THREADS: usize = 8;
-
 #[derive(Debug)]
 pub struct CoreBPE {
     pub encoder: HashMap<Vec<u8>, Rank>,
     special_tokens_encoder: HashMap<String, Rank>,
     decoder: HashMap<Rank, &'static [u8]>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
-    regex_tls: Arc<[Regex]>,
-    special_regex_tls: Arc<[Regex]>,
+    regex: Regex,
+    special_regex: Regex,
+    regex_tls: ThreadLocal<Regex>,
+    special_regex_tls: ThreadLocal<Regex>,
     sorted_token_bytes: Vec<&'static [u8]>,
 }
 
 impl CoreBPE {
     fn _get_tl_regex(&self) -> &Regex {
-        // See performance notes above for what this is about
-        // It's also a little janky, please make a better version of it!
-        // However, it's nice that this doesn't leak memory to short-lived threads
-        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
+        self.regex_tls.get_or(|| self.regex.clone())
     }
 
     fn _get_tl_special_regex(&self) -> &Regex {
-        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
+        self.special_regex_tls.get_or(|| self.special_regex.clone())
     }
 
     fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
@@ -460,16 +440,10 @@
             special_tokens_encoder,
             decoder,
             special_tokens_decoder,
-            regex_tls: Arc::from(
-                (0..MAX_NUM_THREADS)
-                    .map(|_| regex.clone())
-                    .collect::<Vec<_>>(),
-            ),
-            special_regex_tls: Arc::from(
-                (0..MAX_NUM_THREADS)
-                    .map(|_| special_regex.clone())
-                    .collect::<Vec<_>>(),
-            ),
+            regex,
+            special_regex,
+            regex_tls: ThreadLocal::new(),
+            special_regex_tls: ThreadLocal::new(),
             sorted_token_bytes,
         })
     }
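
As a rough usage sketch of the new pattern (illustrative only; Tokenizer below is a stand-in for CoreBPE, not code from this repo): the struct keeps the compiled template regex plus an initially empty ThreadLocal, and each worker thread clones the template at most once, on first access.

// Illustrative sketch of the pattern this commit adopts.
use fancy_regex::Regex;
use thread_local::ThreadLocal;

struct Tokenizer {
    regex: Regex,                  // template, compiled once in the constructor
    regex_tls: ThreadLocal<Regex>, // starts empty, filled lazily per thread
}

impl Tokenizer {
    fn tl_regex(&self) -> &Regex {
        // First call on a given thread clones the template; later calls on
        // that thread return the same per-thread instance.
        self.regex_tls.get_or(|| self.regex.clone())
    }
}

fn main() {
    let tok = Tokenizer {
        regex: Regex::new(r"\w+").unwrap(),
        regex_tls: ThreadLocal::new(),
    };

    // A fixed pool of worker threads creates at most one clone each.
    std::thread::scope(|s| {
        for _ in 0..4 {
            s.spawn(|| {
                let re = tok.tl_regex();
                assert!(re.is_match("hello world").unwrap());
            });
        }
    });
}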

0 commit comments
