|
1 | | -use std::num::NonZeroU64; |
2 | | -use std::thread; |
3 | | - |
4 | 1 | use fancy_regex::Regex; |
5 | 2 | use rustc_hash::FxHashMap as HashMap; |
6 | 3 | use rustc_hash::FxHashSet as HashSet; |
7 | 4 | use std::sync::Arc; |
| 5 | +use thread_local::ThreadLocal; |
8 | 6 |
|
9 | 7 | pub type Rank = u32; |
10 | 8 |
|
@@ -129,44 +127,26 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V |
129 | 127 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made |
130 | 128 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. |
131 | 129 |
|
132 | | -pub struct FakeThreadId(NonZeroU64); |
133 | | - |
134 | | -fn hash_current_thread() -> usize { |
135 | | - // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter |
136 | | - // that works great for our use case of avoiding collisions in our array. Unfortunately, |
137 | | - // it's private. However, there are only so many ways you can layout a u64, so just transmute |
138 | | - // https://github.com/rust-lang/rust/issues/67939 |
139 | | - const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()]; |
140 | | - const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()]; |
141 | | - let x = unsafe { |
142 | | - std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0 |
143 | | - }; |
144 | | - u64::from(x) as usize |
145 | | -} |
146 | | - |
147 | | -const MAX_NUM_THREADS: usize = 8; |
148 | | - |
149 | 130 | #[derive(Debug)] |
150 | 131 | pub struct CoreBPE { |
151 | 132 | pub encoder: HashMap<Vec<u8>, Rank>, |
152 | 133 | special_tokens_encoder: HashMap<String, Rank>, |
153 | 134 | decoder: HashMap<Rank, &'static [u8]>, |
154 | 135 | special_tokens_decoder: HashMap<Rank, Vec<u8>>, |
155 | | - regex_tls: Arc<[Regex]>, |
156 | | - special_regex_tls: Arc<[Regex]>, |
| 136 | + regex: Regex, |
| 137 | + special_regex: Regex, |
| 138 | + regex_tls: ThreadLocal<Regex>, |
| 139 | + special_regex_tls: ThreadLocal<Regex>, |
157 | 140 | sorted_token_bytes: Vec<&'static [u8]>, |
158 | 141 | } |
159 | 142 |
|
160 | 143 | impl CoreBPE { |
161 | 144 | fn _get_tl_regex(&self) -> &Regex { |
162 | | - // See performance notes above for what this is about |
163 | | - // It's also a little janky, please make a better version of it! |
164 | | - // However, it's nice that this doesn't leak memory to short-lived threads |
165 | | - &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 145 | + self.regex_tls.get_or(|| self.regex.clone()) |
166 | 146 | } |
167 | 147 |
|
168 | 148 | fn _get_tl_special_regex(&self) -> &Regex { |
169 | | - &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 149 | + self.special_regex_tls.get_or(|| self.special_regex.clone()) |
170 | 150 | } |
171 | 151 |
|
172 | 152 | fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> { |
@@ -460,16 +440,10 @@ impl CoreBPE { |
460 | 440 | special_tokens_encoder, |
461 | 441 | decoder, |
462 | 442 | special_tokens_decoder, |
463 | | - regex_tls: Arc::from( |
464 | | - (0..MAX_NUM_THREADS) |
465 | | - .map(|_| regex.clone()) |
466 | | - .collect::<Vec<_>>(), |
467 | | - ), |
468 | | - special_regex_tls: Arc::from( |
469 | | - (0..MAX_NUM_THREADS) |
470 | | - .map(|_| special_regex.clone()) |
471 | | - .collect::<Vec<_>>(), |
472 | | - ), |
| 443 | + regex, |
| 444 | + special_regex, |
| 445 | + regex_tls: ThreadLocal::new(), |
| 446 | + special_regex_tls: ThreadLocal::new(), |
473 | 447 | sorted_token_bytes, |
474 | 448 | }) |
475 | 449 | } |
|
0 commit comments