Skip to content

Commit 615bf1f

Browse files
authored
remove Arc<> for encoder hashmap and use directly (#12)
Partial revert of #11; we were seeing a weird perf regression in benchmarks. Rust's borrow checker makes it really hard for two things to own the same data, so take the simple way out by letting CoreBPE own mergeable_ranks.
1 parent 617cd6a commit 615bf1f

File tree

3 files changed

+8
-12
lines changed

3 files changed

+8
-12
lines changed

src/corebpe.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ const MAX_NUM_THREADS: usize = 8;
166166

167167
#[derive(Debug)]
168168
pub struct CoreBPE {
169-
encoder: Arc<HashMap<Vec<u8>, usize>>,
169+
pub encoder: HashMap<Vec<u8>, usize>,
170170
special_tokens_encoder: HashMap<String, usize>,
171171
decoder: HashMap<usize, Vec<u8>>,
172172
special_tokens_decoder: HashMap<usize, Vec<u8>>,
@@ -429,7 +429,7 @@ impl CoreBPE {
429429

430430
impl CoreBPE {
431431
pub fn new(
432-
encoder: Arc<HashMap<Vec<u8>, usize>>,
432+
encoder: HashMap<Vec<u8>, usize>,
433433
special_tokens_encoder: HashMap<String, usize>,
434434
pattern: &str,
435435
) -> Result<Self, fancy_regex::Error> {

src/encoding.rs

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@ pub struct Encoding {
1313
pub name: String,
1414
/// The regular expression pattern used to split text into pieces.
1515
pat_str: String,
16-
/// The map from mergeable byte sequences to their ranks.
17-
mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
1816
/// The maximum length of the keys in `mergeable_ranks`.
1917
mergeable_ranks_max_key_len: usize,
2018
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
@@ -64,7 +62,7 @@ impl Encoding {
6462
pub fn new(
6563
name: &str,
6664
pat_str: &str,
67-
mergeable_ranks: Arc<HashMap<Vec<u8>, usize>>,
65+
mergeable_ranks: HashMap<Vec<u8>, usize>,
6866
special_tokens: HashMap<String, usize>,
6967
explicit_n_vocab: Option<usize>,
7068
) -> Result<Self, EncodingError> {
@@ -113,7 +111,6 @@ impl Encoding {
113111
Ok(Self {
114112
name: name.to_string(),
115113
pat_str: pat_str.to_string(),
116-
mergeable_ranks,
117114
mergeable_ranks_max_key_len,
118115
prefixes_of_mergeable_ranks,
119116
special_tokens,
@@ -157,7 +154,7 @@ impl Encoding {
157154
if current_token.len() > 1 {
158155
new_current_token.clear();
159156
new_current_token.push(current_token.pop().unwrap());
160-
while !self.mergeable_ranks.contains_key(&current_token) {
157+
while !self.core_bpe.encoder.contains_key(&current_token) {
161158
if current_token.len() == 1 {
162159
break;
163160
}
@@ -177,14 +174,14 @@ impl Encoding {
177174
}
178175
}
179176

180-
while !self.mergeable_ranks.contains_key(&current_token) {
177+
while !self.core_bpe.encoder.contains_key(&current_token) {
181178
if current_token.len() == 0 {
182179
break;
183180
}
184181
if current_token.len() > 1 {
185182
new_current_token.clear();
186183
new_current_token.push(current_token.pop().unwrap());
187-
while !self.mergeable_ranks.contains_key(&current_token) {
184+
while !self.core_bpe.encoder.contains_key(&current_token) {
188185
if current_token.len() == 1 {
189186
break;
190187
}

src/load.rs

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@ use sha2::Sha256;
55
// call its methods without adding to the namespace.
66
use base64::engine::general_purpose::STANDARD as BASE64;
77
use base64::engine::Engine as _;
8-
use std::sync::Arc;
98

109
// define the error
1110
#[derive(Debug, Clone)]
@@ -17,7 +16,7 @@ pub enum Error {
1716
pub fn load_tiktoken_bpe(
1817
tiktoken_bpe_contents: &[u8],
1918
shasum: &str,
20-
) -> Result<Arc<HashMap<Vec<u8>, usize>>, Error> {
19+
) -> Result<HashMap<Vec<u8>, usize>, Error> {
2120
// check the shasum
2221
let mut hasher = Sha256::new();
2322
hasher.update(tiktoken_bpe_contents);
@@ -43,5 +42,5 @@ pub fn load_tiktoken_bpe(
4342
map.insert(token, rank);
4443
}
4544
map.shrink_to_fit();
46-
Ok(Arc::new(map))
45+
Ok(map)
4746
}

0 commit comments

Comments (0)