Skip to content

Commit 1afcaaf

Browse files
committed
tested render jinja template
1 parent d47ed3e commit 1afcaaf

File tree

3 files changed

+89
-14
lines changed

3 files changed

+89
-14
lines changed

mlx-lm-utils/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/hf_cache/

mlx-lm-utils/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ documentation.workspace = true
1111

1212
[dependencies]
1313
minijinja = { version = "2", features = ["loader"] }
14+
minijinja-contrib = { version = "2", features = ["pycompat"] }
1415
serde = { version = "1", features = ["derive"] }
1516
serde_json = "1"
1617
thiserror = "2"
17-
tokenizers = "0.21"
18+
tokenizers = "0.21"
19+
20+
[dev-dependencies]
21+
hf-hub = "=0.4.1" # 0.4.2 uses features that went stable in 1.82 while our MSRV is 1.81

mlx-lm-utils/src/tokenizer.rs

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,21 @@ use crate::error::Error;
7474

7575
#[derive(Serialize)]
7676
#[serde(untagged)]
77-
pub enum Content<T: Serialize = ()> {
78-
String(String),
79-
Map(HashMap<String, String>),
77+
pub enum Content<T=String> {
8078
Typed(T),
79+
Map(HashMap<String, String>),
80+
}
81+
82+
#[derive(Debug, Serialize)]
83+
#[serde(rename_all = "lowercase")]
84+
pub enum Role {
85+
User,
86+
Assistant,
8187
}
8288

8389
#[derive(Serialize)]
84-
pub struct Conversation<T: Serialize = ()> {
85-
pub role: String,
90+
pub struct Conversation<R=Role, T=String> {
91+
pub role: R,
8692
pub content: Content<T>,
8793
}
8894

@@ -103,7 +109,7 @@ pub struct ApplyChatTemplateArgs<'a> {
103109
pub tools: Option<Box<dyn FnOnce()>>, // TODO: how to get response?
104110
pub documents: Option<&'a [Documents]>,
105111
pub model_id: &'a str,
106-
pub chat_template: Option<&'a str>,
112+
pub chat_template_id: Option<&'a str>,
107113
pub add_generation_prompt: Option<bool>,
108114
pub continue_final_message: Option<bool>,
109115
}
@@ -115,7 +121,7 @@ pub struct TokenizeOptions {
115121
pub return_assistant_tokens_mask: Option<bool>,
116122
}
117123

118-
pub fn load_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
124+
pub fn load_model_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
119125
serde_json::from_str::<serde_json::Value>(content).map(|value| {
120126
value
121127
.get("chat_template")
@@ -125,9 +131,9 @@ pub fn load_chat_template_from_str(content: &str) -> std::io::Result<Option<Stri
125131
.map_err(Into::into)
126132
}
127133

128-
pub fn load_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Result<Option<String>> {
134+
pub fn load_model_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Result<Option<String>> {
129135
let content = read_to_string(file)?;
130-
load_chat_template_from_str(&content)
136+
load_model_chat_template_from_str(&content)
131137
}
132138

133139
// chat_template = self.get_chat_template(chat_template, tools)
@@ -310,20 +316,20 @@ pub fn apply_chat_template<'a>(
310316
tools,
311317
documents,
312318
model_id,
313-
chat_template,
319+
chat_template_id,
314320
add_generation_prompt,
315321
continue_final_message,
316322
} = args;
317323

318324
let add_generation_prompt = add_generation_prompt.unwrap_or(false);
319325
let continue_final_message = continue_final_message.unwrap_or(false);
320326

321-
let template = match chat_template {
322-
Some(chat_template) => env.get_template(&chat_template)?,
327+
let template = match chat_template_id {
328+
Some(chat_template_id) => env.get_template(&chat_template_id)?,
323329
None => match env.get_template(model_id) {
324330
Ok(template) => template,
325331
Err(_) => {
326-
env.add_template(model_id, model_template)?;
332+
env.add_template_owned(model_id, model_template.to_owned())?;
327333
env.get_template(model_id)
328334
.expect("Newly added template must be present")
329335
}
@@ -348,3 +354,67 @@ pub fn apply_chat_template<'a>(
348354

349355
Ok(rendered_chat)
350356
}
357+
358+
#[cfg(test)]
359+
mod tests {
360+
use std::path::PathBuf;
361+
362+
use hf_hub::{api::sync::ApiBuilder, Repo};
363+
use minijinja::Environment;
364+
365+
use crate::tokenizer::{apply_chat_template, load_model_chat_template_from_file, ApplyChatTemplateArgs, Conversation, Role};
366+
367+
#[test]
368+
fn test_load_chat_template_from_file() {
369+
let hf_cache_dir = PathBuf::from("./hf_cache");
370+
371+
let api = ApiBuilder::new()
372+
.with_endpoint("https://hf-mirror.com".to_string()) // mirror endpoint; comment out this line if huggingface.co is directly reachable from your region
373+
.with_cache_dir(hf_cache_dir)
374+
.build().unwrap();
375+
let model_id = "mlx-community/Qwen3-4B-bf16".to_string();
376+
let repo = api.repo(Repo::new(model_id, hf_hub::RepoType::Model));
377+
let file = repo.get("tokenizer_config.json").unwrap();
378+
let chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
379+
assert!(!chat_template.is_empty());
380+
}
381+
382+
#[test]
383+
fn test_apply_chat_template() {
384+
let hf_cache_dir = PathBuf::from("./hf_cache");
385+
386+
let api = ApiBuilder::new()
387+
.with_endpoint("https://hf-mirror.com".to_string()) // mirror endpoint; comment out this line if huggingface.co is directly reachable from your region
388+
.with_cache_dir(hf_cache_dir)
389+
.build().unwrap();
390+
let model_id = "mlx-community/Qwen3-4B-bf16".to_string();
391+
392+
let conversations = vec![
393+
Conversation {
394+
role: Role::User,
395+
content: crate::tokenizer::Content::Typed("hello".to_string())
396+
}
397+
];
398+
399+
let repo = api.repo(Repo::new(model_id.clone(), hf_hub::RepoType::Model));
400+
let file = repo.get("tokenizer_config.json").unwrap();
401+
let model_chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
402+
assert!(!model_chat_template.is_empty());
403+
404+
let args = ApplyChatTemplateArgs {
405+
conversations: &conversations,
406+
tools: None,
407+
documents: None,
408+
model_id: &model_id,
409+
chat_template_id: None,
410+
add_generation_prompt: None,
411+
continue_final_message: None,
412+
};
413+
414+
let mut env = Environment::new();
415+
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
416+
417+
let rendered_chat = apply_chat_template(&mut env, &model_chat_template, args).unwrap();
418+
println!("{:?}", rendered_chat);
419+
}
420+
}

0 commit comments

Comments
 (0)