@@ -4,8 +4,12 @@ use rustc_hash::FxHashMap as HashMap;
4
4
use rustc_hash:: FxHashSet as HashSet ;
5
5
use std:: sync:: Arc ;
6
6
use thiserror:: Error ;
7
+ use odht:: HashTableOwned ;
7
8
use crate :: rollhash:: { roll_hash, roll_hash_slice} ;
8
9
10
+ include ! ( "odht.rs" ) ;
11
+ include ! ( concat!( env!( "OUT_DIR" ) , "/static.rs" ) ) ;
12
+
9
13
/// A struct that represents an encoding scheme based on byte-pair encoding (BPE).
10
14
#[ derive( Debug ) ]
11
15
pub struct Encoding {
@@ -16,7 +20,7 @@ pub struct Encoding {
16
20
/// The maximum length of the keys in `mergeable_ranks`.
17
21
mergeable_ranks_max_key_len : usize ,
18
22
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
19
- prefixes_of_mergeable_ranks : HashSet < i64 > ,
23
+ prefixes_of_mergeable_ranks : HashTableOwned < PrefixConfig > ,
20
24
/// The map from special token strings to their values.
21
25
special_tokens : HashMap < String , usize > ,
22
26
/// The maximum token value in the encoding.
@@ -97,16 +101,18 @@ impl Encoding {
97
101
)
98
102
. map_err ( |e| EncodingError :: GenericEncodingError ( format ! ( "Error creating core BPE: {}" , e) ) ) ?;
99
103
100
- let mut prefixes_of_mergeable_ranks = mergeable_ranks
101
- . keys ( )
102
- . flat_map ( |bytes| {
103
- ( 1 ..=bytes. len ( ) )
104
- . map ( |i| roll_hash_slice ( & bytes[ ..i] ) )
105
- . collect :: < Vec < _ > > ( )
104
+ let prefixes_of_mergeable_ranks = unsafe {
105
+ HashTableOwned :: < PrefixConfig > :: from_raw_bytes_unchecked ( match name {
106
+ "r50k_base" => data:: R50K_BASE_PREFIXES_ODHT ,
107
+ "p50k_base" => data:: P50K_BASE_PREFIXES_ODHT ,
108
+ "cl100k_base" => data:: CL100K_BASE_PREFIXES_ODHT ,
109
+ "o200k_base" => data:: O200K_BASE_PREFIXES_ODHT ,
110
+ "codestral" => data:: CODESTRAL_PREFIXES_ODHT ,
111
+ "llama3" => data:: LLAMA3_PREFIXES_ODHT ,
112
+ "deepseekv2" => data:: DEEPSEEKV2_PREFIXES_ODHT ,
113
+ _ => return Err ( EncodingError :: GenericEncodingError ( format ! ( "Embedded prefix table not found for encoding: {}" , name) ) ) ,
106
114
} )
107
- . collect :: < HashSet < _ > > ( ) ;
108
- prefixes_of_mergeable_ranks. insert ( 0 ) ;
109
- prefixes_of_mergeable_ranks. shrink_to_fit ( ) ;
115
+ } ;
110
116
111
117
Ok ( Self {
112
118
name : name. to_string ( ) ,
@@ -148,7 +154,7 @@ impl Encoding {
148
154
// or if the current token is not in the prefixes of mergeable ranks,
149
155
// we need to split the current token and begin actually checking for the largest
150
156
// mergeable prefix
151
- while !self . prefixes_of_mergeable_ranks . contains ( & current_token_hash)
157
+ while !self . prefixes_of_mergeable_ranks . contains_key ( & current_token_hash)
152
158
|| current_token. len ( ) > self . mergeable_ranks_max_key_len
153
159
{
154
160
if current_token. len ( ) > 1 {
0 commit comments