Skip to content

Commit fccdaa8

Browse files
committed
added wrapper Tokenizer
1 parent 1afcaaf commit fccdaa8

File tree

1 file changed

+126
-16
lines changed

1 file changed

+126
-16
lines changed

mlx-lm-utils/src/tokenizer.rs

Lines changed: 126 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,31 +65,97 @@
6565
// set, will return a dict of tokenizer outputs instead.
6666
// """
6767

68-
use std::{borrow::Cow, collections::HashMap, fs::read_to_string, path::Path};
68+
use std::{borrow::Cow, collections::HashMap, fs::read_to_string, ops::{Deref, DerefMut}, path::Path, str::FromStr};
6969

7070
use minijinja::{context, Environment, Template};
7171
use serde::{Deserialize, Serialize};
72+
use tokenizers::Encoding;
7273

7374
use crate::error::Error;
7475

75-
#[derive(Serialize)]
76-
#[serde(untagged)]
77-
pub enum Content<T=String> {
78-
Typed(T),
79-
Map(HashMap<String, String>),
76+
/// Wrapper around [`tokenizers::Tokenizer`] and [`minijinja::Environment`]
77+
/// providing more utilities.
78+
pub struct Tokenizer<'a> {
79+
inner: tokenizers::Tokenizer,
80+
env: Environment<'a>,
81+
}
82+
83+
impl<'a> Tokenizer<'a> {
84+
pub fn from_tokenizer(tokenizer: tokenizers::Tokenizer) -> Self {
85+
let mut env = Environment::new();
86+
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
87+
Self { inner: tokenizer, env }
88+
}
89+
90+
pub fn from_file(file: impl AsRef<Path>) -> tokenizers::Result<Self> {
91+
tokenizers::Tokenizer::from_file(file)
92+
.map(Self::from_tokenizer)
93+
}
94+
95+
pub fn from_bytes(bytes: impl AsRef<[u8]>) -> tokenizers::Result<Self> {
96+
tokenizers::Tokenizer::from_bytes(bytes)
97+
.map(Self::from_tokenizer)
98+
}
99+
100+
pub fn from_str(s: &str) -> tokenizers::Result<Self> {
101+
tokenizers::Tokenizer::from_str(s)
102+
.map(Self::from_tokenizer)
103+
}
104+
105+
pub fn apply_chat_template<R, T>(
106+
&'a mut self,
107+
model_template: impl Into<Cow<'a, str>>,
108+
args: ApplyChatTemplateArgs<'a, R, T>
109+
) -> Result<String, Error>
110+
where
111+
R: Serialize,
112+
T: Serialize,
113+
{
114+
apply_chat_template(&mut self.env, model_template, args)
115+
}
116+
117+
pub fn apply_chat_template_and_encode<R, T>(
118+
&'a mut self,
119+
model_template: impl Into<String>,
120+
args: ApplyChatTemplateArgs<'a, R, T>,
121+
tokenize_options: TokenizeOptions,
122+
) -> Result<Encoding, Error> {
123+
todo!()
124+
}
125+
}
126+
127+
impl Deref for Tokenizer<'_> {
128+
type Target = tokenizers::Tokenizer;
129+
130+
fn deref(&self) -> &Self::Target {
131+
&self.inner
132+
}
133+
}
134+
135+
impl DerefMut for Tokenizer<'_> {
136+
fn deref_mut(&mut self) -> &mut Self::Target {
137+
&mut self.inner
138+
}
80139
}
81140

141+
82142
#[derive(Debug, Serialize)]
83143
#[serde(rename_all = "lowercase")]
84144
pub enum Role {
85145
User,
86146
Assistant,
87147
}
88148

149+
#[derive(Debug, Serialize)]
150+
pub enum Content {
151+
String(String),
152+
Map(HashMap<String, String>)
153+
}
154+
89155
#[derive(Serialize)]
90-
pub struct Conversation<R=Role, T=String> {
156+
pub struct Conversation<R, T> {
91157
pub role: R,
92-
pub content: Content<T>,
158+
pub content: T,
93159
}
94160

95161
pub type Documents = HashMap<String, String>;
@@ -104,8 +170,8 @@ pub enum Truncation {
104170
}
105171

106172
#[derive(Default)]
107-
pub struct ApplyChatTemplateArgs<'a> {
108-
pub conversations: &'a [Conversation],
173+
pub struct ApplyChatTemplateArgs<'a, R=Role, T=String> {
174+
pub conversations: &'a [Conversation<R, T>],
109175
pub tools: Option<Box<dyn FnOnce()>>, // TODO: how to get response?
110176
pub documents: Option<&'a [Documents]>,
111177
pub model_id: &'a str,
@@ -306,11 +372,15 @@ pub fn load_model_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Re
306372

307373
// return rendered, all_generation_indices
308374

309-
pub fn apply_chat_template<'a>(
375+
pub fn apply_chat_template<'a, R, T>(
310376
env: &'a mut Environment<'a>,
311-
model_template: &'a str,
312-
args: ApplyChatTemplateArgs<'a>,
313-
) -> Result<String, Error> {
377+
model_template: impl Into<Cow<'a, str>>,
378+
args: ApplyChatTemplateArgs<'a, R, T>,
379+
) -> Result<String, Error>
380+
where
381+
R: Serialize,
382+
T: Serialize,
383+
{
314384
let ApplyChatTemplateArgs {
315385
conversations,
316386
tools,
@@ -329,7 +399,7 @@ pub fn apply_chat_template<'a>(
329399
None => match env.get_template(model_id) {
330400
Ok(template) => template,
331401
Err(_) => {
332-
env.add_template_owned(model_id, model_template.to_owned())?;
402+
env.add_template_owned(model_id, model_template)?;
333403
env.get_template(model_id)
334404
.expect("Newly added template must be present")
335405
}
@@ -392,7 +462,7 @@ mod tests {
392462
let conversations = vec![
393463
Conversation {
394464
role: Role::User,
395-
content: crate::tokenizer::Content::Typed("hello".to_string())
465+
content: "hello",
396466
}
397467
];
398468

@@ -417,4 +487,44 @@ mod tests {
417487
let rendered_chat = apply_chat_template(&mut env, &model_chat_template, args).unwrap();
418488
println!("{:?}", rendered_chat);
419489
}
490+
491+
#[test]
fn test_tokenizer_apply_chat_template() {
    // Local cache directory so repeated runs don't re-download model files.
    let hf_cache_dir = PathBuf::from("./hf_cache");

    // Mirror endpoint for regions where huggingface.co is unreachable;
    // comment out this line if your area is not banned.
    let api = ApiBuilder::new()
        .with_endpoint("https://hf-mirror.com".to_string())
        .with_cache_dir(hf_cache_dir)
        .build()
        .unwrap();
    let model_id = "mlx-community/Qwen3-4B-bf16".to_string();

    // Single user turn; `content: &str` exercises the generic `T` path.
    let conversations = vec![Conversation {
        role: Role::User,
        content: "hello",
    }];

    // Fetch tokenizer + tokenizer config from the hub.
    let repo = api.repo(Repo::new(model_id.clone(), hf_hub::RepoType::Model));
    let tokenizer_file = repo.get("tokenizer.json").unwrap();
    let tokenizer_config_file = repo.get("tokenizer_config.json").unwrap();

    let mut tokenizer = super::Tokenizer::from_file(tokenizer_file).unwrap();

    let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)
        .unwrap()
        .unwrap();
    assert!(!model_chat_template.is_empty());

    let args = ApplyChatTemplateArgs {
        conversations: &conversations,
        tools: None,
        documents: None,
        model_id: &model_id,
        chat_template_id: None,
        add_generation_prompt: None,
        continue_final_message: None,
    };

    // Render through the wrapper; failure here means template or args broke.
    let rendered_chat = tokenizer
        .apply_chat_template(&model_chat_template, args)
        .unwrap();
    println!("{:?}", rendered_chat);
}
420530
}

0 commit comments

Comments
 (0)