|
1 |
| -use std::num::NonZeroU64; |
2 |
| -use std::thread; |
3 |
| - |
4 | 1 | use fancy_regex::Regex;
|
5 | 2 | use rustc_hash::FxHashMap as HashMap;
|
6 | 3 | use rustc_hash::FxHashSet as HashSet;
|
7 | 4 | use std::sync::Arc;
|
| 5 | +use thread_local::ThreadLocal; |
8 | 6 |
|
9 | 7 | pub type Rank = u32;
|
10 | 8 |
|
@@ -129,44 +127,26 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> V
|
129 | 127 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
|
130 | 128 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
|
131 | 129 |
|
132 |
| -pub struct FakeThreadId(NonZeroU64); |
133 |
| - |
134 |
| -fn hash_current_thread() -> usize { |
135 |
| - // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter |
136 |
| - // that works great for our use case of avoiding collisions in our array. Unfortunately, |
137 |
| - // it's private. However, there are only so many ways you can layout a u64, so just transmute |
138 |
| - // https://github.com/rust-lang/rust/issues/67939 |
139 |
| - const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()]; |
140 |
| - const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()]; |
141 |
| - let x = unsafe { |
142 |
| - std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0 |
143 |
| - }; |
144 |
| - u64::from(x) as usize |
145 |
| -} |
146 |
| - |
147 |
| -const MAX_NUM_THREADS: usize = 8; |
148 |
| - |
149 | 130 | #[derive(Debug)]
|
150 | 131 | pub struct CoreBPE {
|
151 | 132 | pub encoder: HashMap<Vec<u8>, Rank>,
|
152 | 133 | special_tokens_encoder: HashMap<String, Rank>,
|
153 | 134 | decoder: HashMap<Rank, &'static [u8]>,
|
154 | 135 | special_tokens_decoder: HashMap<Rank, Vec<u8>>,
|
155 |
| - regex_tls: Arc<[Regex]>, |
156 |
| - special_regex_tls: Arc<[Regex]>, |
| 136 | + regex: Regex, |
| 137 | + special_regex: Regex, |
| 138 | + regex_tls: ThreadLocal<Regex>, |
| 139 | + special_regex_tls: ThreadLocal<Regex>, |
157 | 140 | sorted_token_bytes: Vec<&'static [u8]>,
|
158 | 141 | }
|
159 | 142 |
|
160 | 143 | impl CoreBPE {
|
161 | 144 | fn _get_tl_regex(&self) -> &Regex {
|
162 |
| - // See performance notes above for what this is about |
163 |
| - // It's also a little janky, please make a better version of it! |
164 |
| - // However, it's nice that this doesn't leak memory to short-lived threads |
165 |
| - &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 145 | + self.regex_tls.get_or(|| self.regex.clone()) |
166 | 146 | }
|
167 | 147 |
|
168 | 148 | fn _get_tl_special_regex(&self) -> &Regex {
|
169 |
| - &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 149 | + self.special_regex_tls.get_or(|| self.special_regex.clone()) |
170 | 150 | }
|
171 | 151 |
|
172 | 152 | fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
|
@@ -460,16 +440,10 @@ impl CoreBPE {
|
460 | 440 | special_tokens_encoder,
|
461 | 441 | decoder,
|
462 | 442 | special_tokens_decoder,
|
463 |
| - regex_tls: Arc::from( |
464 |
| - (0..MAX_NUM_THREADS) |
465 |
| - .map(|_| regex.clone()) |
466 |
| - .collect::<Vec<_>>(), |
467 |
| - ), |
468 |
| - special_regex_tls: Arc::from( |
469 |
| - (0..MAX_NUM_THREADS) |
470 |
| - .map(|_| special_regex.clone()) |
471 |
| - .collect::<Vec<_>>(), |
472 |
| - ), |
| 443 | + regex, |
| 444 | + special_regex, |
| 445 | + regex_tls: ThreadLocal::new(), |
| 446 | + special_regex_tls: ThreadLocal::new(), |
473 | 447 | sorted_token_bytes,
|
474 | 448 | })
|
475 | 449 | }
|
|
0 commit comments