Commit 9f1384f

fix(tokenizer): match special sequences first
Signed-off-by: YdrMaster <[email protected]>
1 parent f688768 commit 9f1384f

File tree

7 files changed: +168 −20 lines changed


Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

service/src/lib.rs

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,7 @@ use std::{
     path::Path,
     sync::Arc,
 };
-use tokenizer::{BPECommonNormalizer, Normalizer, Tokenizer, VocabTxt, BPE};
+use tokenizer::{BPECommonNormalizer, Normalizer, Tokenize, Tokenizer, VocabTxt, BPE};
 use tokio::task::JoinHandle;
 
 pub use chat_template::Message;
@@ -29,7 +29,7 @@ pub struct Service<M: CausalLM> {
 /// The inference thread's lifecycle is bound to this component.
 struct ServiceComponent<M: CausalLM> {
     handle: Arc<Dispatcher<M>>,
-    tokenizer: Box<dyn Tokenizer + Send + Sync>,
+    tokenizer: Box<dyn Tokenize + Send + Sync>,
     normalizer: Box<dyn Normalizer + Send + Sync>,
     template: ChatTemplate,
     bos: String,
@@ -165,10 +165,10 @@ fn normalizer(model_dir: impl AsRef<Path>) -> Box<dyn Normalizer + Send + Sync>
     panic!("Tokenizer file not found");
 }
 
-fn tokenizer(model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer + Send + Sync> {
+fn tokenizer(model_dir: impl AsRef<Path>) -> Box<dyn Tokenize + Send + Sync> {
     use std::io::ErrorKind::NotFound;
     match BPE::from_tokenizer_model(model_dir.as_ref().join("tokenizer.model")) {
-        Ok(bpe) => return Box::new(bpe),
+        Ok(bpe) => return Box::new(Tokenizer::new(bpe)),
         Err(e) if e.kind() == NotFound => {}
         Err(e) => panic!("{e:?}"),
     }
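
The service now stores its tokenizer behind the object-safe `Tokenize` trait and wraps the BPE back end in the new special-sequence-aware `Tokenizer`. A minimal sketch of that wiring, assuming a `tokenizer.model` next to the model files (the `load` helper name is invented for illustration):

```rust
use std::path::Path;
use tokenizer::{Tokenize, Tokenizer, BPE};

// Hypothetical helper mirroring the service's `tokenizer()` function: the BPE
// back end is wrapped by `Tokenizer::new`, and the wrapper is what gets boxed
// behind the object-safe `Tokenize` trait the service keeps using.
fn load(model_dir: &Path) -> Box<dyn Tokenize + Send + Sync> {
    let bpe = BPE::from_tokenizer_model(model_dir.join("tokenizer.model"))
        .expect("tokenizer.model not found or malformed");
    Box::new(Tokenizer::new(bpe))
}
```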

tokenizer/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ authors = ["YdrMaster <[email protected]>"]
 [dependencies]
 memmap2.workspace = true
 patricia_tree = "0.8"
+regex = "1.10"

tokenizer/src/bpe.rs

Lines changed: 52 additions & 13 deletions
@@ -1,4 +1,4 @@
-use crate::{as_byte_token, utok, Tokenizer};
+use crate::{as_byte_token, utok, Method};
 use std::{
     collections::{HashMap, HashSet},
     io,
@@ -161,7 +161,7 @@ impl BPE {
             .iter()
             .filter_map(|&t| {
                 let s = unsafe { std::str::from_utf8_unchecked(self.token(t)) };
-                if self.encode(s).len() > 1 {
+                if self.encode(s).into_iter().nth(1).is_some() {
                     Some((s, t))
                 } else {
                     None
@@ -192,22 +192,28 @@ impl BPE {
     }
 }
 
-impl Tokenizer for BPE {
+impl Method for BPE {
+    #[inline]
+    fn unk_token(&self) -> utok {
+        self.unk
+    }
     #[inline]
     fn vocab_size(&self) -> usize {
         self.tokens.len()
     }
-
     #[inline]
-    fn encode(&self, text: &str) -> Vec<utok> {
+    fn internal_special(&self) -> impl IntoIterator<Item = (&str, utok)> {
+        self.inaccessible()
+    }
+    #[inline]
+    fn encode<'a>(&'a self, text: &'a str) -> impl IntoIterator<Item = utok> + 'a {
         let mut tokenizer = self.build_tokenizer(text);
         while tokenizer.merge() {}
-        tokenizer.iter().collect()
+        tokenizer
     }
-
     #[inline]
-    fn decode(&self, token: utok) -> &str {
-        unsafe { std::str::from_utf8_unchecked(self.token(token)) }
+    fn decode(&self, token: utok) -> &[u8] {
+        self.token(token)
     }
 }
 
@@ -267,9 +273,15 @@ mod algorithm {
         merges: BinaryHeap<Merge>,
     }
 
+    pub struct IntoIter<'a> {
+        bpe: &'a BPE,
+        marks: Vec<Mark>,
+        i: usize,
+    }
+
     pub struct Iter<'a> {
         bpe: &'a BPE,
-        slice: &'a [Mark],
+        marks: &'a [Mark],
     }
 
     impl BPE {
@@ -450,7 +462,34 @@ mod algorithm {
         pub fn iter(&self) -> Iter {
             Iter {
                 bpe: self.bpe,
-                slice: &self.marks,
+                marks: &self.marks,
+            }
+        }
+    }
+
+    impl<'a> IntoIterator for BpeTokenizer<'a> {
+        type Item = utok;
+        type IntoIter = IntoIter<'a>;
+        #[inline]
+        fn into_iter(self) -> Self::IntoIter {
+            Self::IntoIter {
+                bpe: self.bpe,
+                marks: self.marks,
+                i: 0,
+            }
+        }
+    }
+
+    impl Iterator for IntoIter<'_> {
+        type Item = utok;
+
+        fn next(&mut self) -> Option<Self::Item> {
+            match &self.marks[self.i..] {
+                &[Mark { token, .. }, ..] => {
+                    self.i += self.bpe.token(token).len();
+                    Some(token)
+                }
+                [] => None,
             }
         }
     }
@@ -459,9 +498,9 @@ mod algorithm {
         type Item = utok;
 
         fn next(&mut self) -> Option<Self::Item> {
-            match self.slice {
+            match self.marks {
                 &[Mark { token, .. }, ref tail @ ..] => {
-                    self.slice = &tail[self.bpe.token(token).len() - 1..];
+                    self.marks = &tail[self.bpe.token(token).len() - 1..];
                     Some(token)
                 }
                 [] => None,
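
With the switch to `Method`, `BPE::encode` now yields tokens lazily through the consuming `IntoIter`, and `decode` hands back the token's raw bytes instead of a `&str`. A small sketch of how that surface is consumed (the helper names are illustrative, not from the repo):

```rust
use tokenizer::{utok, Method, BPE};

// `Method::encode` returns `impl IntoIterator<Item = utok>`, so callers decide
// whether to stream the ids or collect them into a Vec as before.
fn collect_ids(bpe: &BPE, text: &str) -> Vec<utok> {
    bpe.encode(text).into_iter().collect()
}

// `Method::decode` now yields raw bytes; turning them back into a string is
// left to a higher layer (the special-sequence wrapper does this).
fn token_bytes(bpe: &BPE, token: utok) -> Vec<u8> {
    bpe.decode(token).to_vec()
}
```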

tokenizer/src/lib.rs

Lines changed: 12 additions & 1 deletion
@@ -2,27 +2,38 @@
 
 mod bpe;
 mod normalizer;
+mod special;
 mod vocab_txt;
 
 /// `utok` for token id.
 #[allow(non_camel_case_types)]
 pub type utok = u32;
 
-pub trait Tokenizer {
+pub trait Tokenize {
     fn vocab_size(&self) -> usize;
     fn encode(&self, text: &str) -> Vec<utok>;
     fn decode(&self, token: utok) -> &str;
 }
 
+pub trait Method {
+    fn unk_token(&self) -> utok;
+    fn vocab_size(&self) -> usize;
+    fn internal_special(&self) -> impl IntoIterator<Item = (&str, utok)>;
+    fn encode<'a>(&'a self, text: &'a str) -> impl IntoIterator<Item = utok> + 'a;
+    fn decode(&self, token: utok) -> &[u8];
+}
+
 pub use bpe::BPE;
 pub use normalizer::{BPECommonNormalizer, Normalizer};
+pub use special::Tokenizer;
 pub use vocab_txt::VocabTxt;
 
 const fn as_byte_token(piece: &[u8]) -> Option<u8> {
     // destructure and convert
     match piece {
         &[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => {
             // convert an ASCII hex digit to its numeric value
+            #[inline(always)]
             const fn to_num(c: u8) -> u8 {
                 match c {
                     b'0'..=b'9' => c - b'0',
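
`Tokenize` stays object-safe for the service, while the new `Method` trait describes a concrete tokenization back end with the hooks the special-sequence wrapper needs. A hypothetical toy implementation (not in the repo) showing what the trait asks for:

```rust
use tokenizer::{utok, Method};

// Hypothetical byte-level back end: every byte is its own token id.
struct ByteLevel;

impl Method for ByteLevel {
    fn unk_token(&self) -> utok {
        0
    }
    fn vocab_size(&self) -> usize {
        256
    }
    // no built-in special sequences for this toy back end
    fn internal_special(&self) -> impl IntoIterator<Item = (&str, utok)> {
        std::iter::empty()
    }
    // encode is lazy: an iterator, not a Vec
    fn encode<'a>(&'a self, text: &'a str) -> impl IntoIterator<Item = utok> + 'a {
        text.bytes().map(|b| b as utok)
    }
    // a real back end returns the token's byte sequence from its vocabulary;
    // this stub has nothing to borrow from, so it returns an empty slice
    fn decode(&self, _token: utok) -> &[u8] {
        &[]
    }
}
```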

tokenizer/src/special.rs

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+use crate::{utok, Method};
+use regex::Regex;
+use std::collections::HashMap;
+
+pub struct Tokenizer<M> {
+    method: M,
+    special: HashMap<String, Vec<utok>>,
+    special_regex: regex::Regex,
+}
+
+impl<M: Method> Tokenizer<M> {
+    pub fn new(method: M) -> Self {
+        let special = method
+            .internal_special()
+            .into_iter()
+            .map(|(k, v)| (k.to_string(), vec![v]))
+            .collect::<HashMap<_, _>>();
+        let special_regex = build_pattern(special.keys());
+        Self {
+            method,
+            special,
+            special_regex,
+        }
+    }
+
+    pub fn extend_special(&mut self, patterns: impl IntoIterator<Item = (String, Vec<utok>)>) {
+        use std::collections::hash_map::Entry::{Occupied, Vacant};
+        let mut any = false;
+        for (k, v) in patterns {
+            match self.special.entry(k) {
+                Occupied(entry) => {
+                    assert_eq!(entry.get(), &v);
+                }
+                Vacant(entry) => {
+                    entry.insert(v);
+                    any = true;
+                }
+            }
+        }
+        if any {
+            self.special_regex = build_pattern(self.special.keys());
+        }
+    }
+
+    pub fn encode(&self, text: &str) -> Vec<utok> {
+        let mut ans = Vec::new();
+        let mut start = 0;
+        for m in self.special_regex.find_iter(text) {
+            ans.extend(self.method.encode(&text[start..m.start()]));
+            ans.extend_from_slice(&self.special[m.as_str()]);
+            start = m.end();
+        }
+        ans.extend(self.method.encode(&text[start..]));
+        ans
+    }
+
+    pub fn decode(&self, tokens: &[utok]) -> String {
+        let mut ans = Vec::new();
+        for &t in tokens {
+            ans.extend_from_slice(self.method.decode(t));
+        }
+        String::from_utf8(ans).unwrap()
+    }
+
+    pub fn internal(&self) -> &M {
+        &self.method
+    }
+}
+
+fn build_pattern<'a, T: AsRef<str>>(text: impl IntoIterator<Item = T>) -> Regex {
+    let mut pattern = String::new();
+    let mut iter = text.into_iter();
+    if let Some(p) = iter.next() {
+        pattern.push_str(p.as_ref());
+    }
+    for p in iter {
+        pattern.push('|');
+        pattern.push_str(p.as_ref());
+    }
+    regex::Regex::new(&pattern).unwrap()
+}
+
+impl crate::Tokenize for Tokenizer<crate::BPE> {
+    #[inline]
+    fn vocab_size(&self) -> usize {
+        self.method.vocab_size()
+    }
+    #[inline]
+    fn encode(&self, text: &str) -> Vec<utok> {
+        self.encode(text)
+    }
+    #[inline]
+    fn decode(&self, token: utok) -> &str {
+        unsafe { std::str::from_utf8_unchecked(self.method.decode(token)) }
+    }
+}
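
This wrapper is what implements the commit's "match special sequences first" behavior: the combined regex scans the input before the inner method runs, so registered sequences map straight to their reserved ids instead of being split by BPE merges. A usage sketch (the special string and id below are invented, not from the repo):

```rust
use tokenizer::{Tokenizer, BPE};

fn demo(bpe: BPE) {
    let mut tokenizer = Tokenizer::new(bpe);
    // hypothetical special sequence and id, added on top of the internal ones
    tokenizer.extend_special([("<eos>".to_string(), vec![2])]);

    // "<eos>" is matched first and contributes exactly the registered id;
    // the surrounding text is encoded by the wrapped BPE method.
    let ids = tokenizer.encode("hello <eos> world");
    assert!(ids.contains(&2));
}
```

Note that `build_pattern` joins the raw special strings with `|` without escaping, so this sketch assumes special sequences free of regex metacharacters.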

tokenizer/src/vocab_txt.rs

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-use crate::{decode_with_ascii, utok, Tokenizer};
+use crate::{decode_with_ascii, utok, Tokenize};
 use memmap2::Mmap;
 use patricia_tree::PatriciaMap;
 use std::{fs::File, io::Result, path::Path};
@@ -28,7 +28,7 @@ impl VocabTxt {
     }
 }
 
-impl Tokenizer for VocabTxt {
+impl Tokenize for VocabTxt {
     fn vocab_size(&self) -> usize {
         self.words.len()
     }

0 commit comments