@@ -74,15 +74,21 @@ use crate::error::Error;
7474
7575#[ derive( Serialize ) ]
7676#[ serde( untagged) ]
77- pub enum Content < T : Serialize = ( ) > {
78- String ( String ) ,
79- Map ( HashMap < String , String > ) ,
77+ pub enum Content < T =String > {
8078 Typed ( T ) ,
79+ Map ( HashMap < String , String > ) ,
80+ }
81+
82+ #[ derive( Debug , Serialize ) ]
83+ #[ serde( rename_all = "lowercase" ) ]
84+ pub enum Role {
85+ User ,
86+ Assistant ,
8187}
8288
8389#[ derive( Serialize ) ]
84- pub struct Conversation < T : Serialize = ( ) > {
85- pub role : String ,
90+ pub struct Conversation < R = Role , T = String > {
91+ pub role : R ,
8692 pub content : Content < T > ,
8793}
8894
@@ -103,7 +109,7 @@ pub struct ApplyChatTemplateArgs<'a> {
103109 pub tools : Option < Box < dyn FnOnce ( ) > > , // TODO: how to get response?
104110 pub documents : Option < & ' a [ Documents ] > ,
105111 pub model_id : & ' a str ,
106- pub chat_template : Option < & ' a str > ,
112+ pub chat_template_id : Option < & ' a str > ,
107113 pub add_generation_prompt : Option < bool > ,
108114 pub continue_final_message : Option < bool > ,
109115}
@@ -115,7 +121,7 @@ pub struct TokenizeOptions {
115121 pub return_assistant_tokens_mask : Option < bool > ,
116122}
117123
118- pub fn load_chat_template_from_str ( content : & str ) -> std:: io:: Result < Option < String > > {
124+ pub fn load_model_chat_template_from_str ( content : & str ) -> std:: io:: Result < Option < String > > {
119125 serde_json:: from_str :: < serde_json:: Value > ( content) . map ( |value| {
120126 value
121127 . get ( "chat_template" )
@@ -125,9 +131,9 @@ pub fn load_chat_template_from_str(content: &str) -> std::io::Result<Option<Stri
125131 . map_err ( Into :: into)
126132}
127133
128- pub fn load_chat_template_from_file ( file : impl AsRef < Path > ) -> std:: io:: Result < Option < String > > {
134+ pub fn load_model_chat_template_from_file ( file : impl AsRef < Path > ) -> std:: io:: Result < Option < String > > {
129135 let content = read_to_string ( file) ?;
130- load_chat_template_from_str ( & content)
136+ load_model_chat_template_from_str ( & content)
131137}
132138
133139// chat_template = self.get_chat_template(chat_template, tools)
@@ -310,20 +316,20 @@ pub fn apply_chat_template<'a>(
310316 tools,
311317 documents,
312318 model_id,
313- chat_template ,
319+ chat_template_id ,
314320 add_generation_prompt,
315321 continue_final_message,
316322 } = args;
317323
318324 let add_generation_prompt = add_generation_prompt. unwrap_or ( false ) ;
319325 let continue_final_message = continue_final_message. unwrap_or ( false ) ;
320326
321- let template = match chat_template {
322- Some ( chat_template ) => env. get_template ( & chat_template ) ?,
327+ let template = match chat_template_id {
328+ Some ( chat_template_id ) => env. get_template ( & chat_template_id ) ?,
323329 None => match env. get_template ( model_id) {
324330 Ok ( template) => template,
325331 Err ( _) => {
326- env. add_template ( model_id, model_template) ?;
332+ env. add_template_owned ( model_id, model_template. to_owned ( ) ) ?;
327333 env. get_template ( model_id)
328334 . expect ( "Newly added template must be present" )
329335 }
@@ -348,3 +354,67 @@ pub fn apply_chat_template<'a>(
348354
349355 Ok ( rendered_chat)
350356}
357+
358+ #[ cfg( test) ]
359+ mod tests {
360+ use std:: path:: PathBuf ;
361+
362+ use hf_hub:: { api:: sync:: ApiBuilder , Repo } ;
363+ use minijinja:: Environment ;
364+
365+ use crate :: tokenizer:: { apply_chat_template, load_model_chat_template_from_file, ApplyChatTemplateArgs , Conversation , Role } ;
366+
367+ #[ test]
368+ fn test_load_chat_template_from_file ( ) {
369+ let hf_cache_dir = PathBuf :: from ( "./hf_cache" ) ;
370+
371+ let api = ApiBuilder :: new ( )
372+ . with_endpoint ( "https://hf-mirror.com" . to_string ( ) ) // comment out this line if your area is not banned
373+ . with_cache_dir ( hf_cache_dir)
374+ . build ( ) . unwrap ( ) ;
375+ let model_id = "mlx-community/Qwen3-4B-bf16" . to_string ( ) ;
376+ let repo = api. repo ( Repo :: new ( model_id, hf_hub:: RepoType :: Model ) ) ;
377+ let file = repo. get ( "tokenizer_config.json" ) . unwrap ( ) ;
378+ let chat_template = load_model_chat_template_from_file ( file) . unwrap ( ) . unwrap ( ) ;
379+ assert ! ( !chat_template. is_empty( ) ) ;
380+ }
381+
382+ #[ test]
383+ fn test_apply_chat_template ( ) {
384+ let hf_cache_dir = PathBuf :: from ( "./hf_cache" ) ;
385+
386+ let api = ApiBuilder :: new ( )
387+ . with_endpoint ( "https://hf-mirror.com" . to_string ( ) ) // comment out this line if your area is not banned
388+ . with_cache_dir ( hf_cache_dir)
389+ . build ( ) . unwrap ( ) ;
390+ let model_id = "mlx-community/Qwen3-4B-bf16" . to_string ( ) ;
391+
392+ let conversations = vec ! [
393+ Conversation {
394+ role: Role :: User ,
395+ content: crate :: tokenizer:: Content :: Typed ( "hello" . to_string( ) )
396+ }
397+ ] ;
398+
399+ let repo = api. repo ( Repo :: new ( model_id. clone ( ) , hf_hub:: RepoType :: Model ) ) ;
400+ let file = repo. get ( "tokenizer_config.json" ) . unwrap ( ) ;
401+ let model_chat_template = load_model_chat_template_from_file ( file) . unwrap ( ) . unwrap ( ) ;
402+ assert ! ( !model_chat_template. is_empty( ) ) ;
403+
404+ let args = ApplyChatTemplateArgs {
405+ conversations : & conversations,
406+ tools : None ,
407+ documents : None ,
408+ model_id : & model_id,
409+ chat_template_id : None ,
410+ add_generation_prompt : None ,
411+ continue_final_message : None ,
412+ } ;
413+
414+ let mut env = Environment :: new ( ) ;
415+ env. set_unknown_method_callback ( minijinja_contrib:: pycompat:: unknown_method_callback) ;
416+
417+ let rendered_chat = apply_chat_template ( & mut env, & model_chat_template, args) . unwrap ( ) ;
418+ println ! ( "{:?}" , rendered_chat) ;
419+ }
420+ }
0 commit comments