6565// set, will return a dict of tokenizer outputs instead.
6666// """
6767
68- use std:: { borrow:: Cow , collections:: HashMap , fs:: read_to_string, path:: Path } ;
68+ use std:: { borrow:: Cow , collections:: HashMap , fs:: read_to_string, ops :: { Deref , DerefMut } , path:: Path , str :: FromStr } ;
6969
7070use minijinja:: { context, Environment , Template } ;
7171use serde:: { Deserialize , Serialize } ;
72+ use tokenizers:: Encoding ;
7273
7374use crate :: error:: Error ;
7475
75- #[ derive( Serialize ) ]
76- #[ serde( untagged) ]
77- pub enum Content < T =String > {
78- Typed ( T ) ,
79- Map ( HashMap < String , String > ) ,
76+ /// Wrapper around [`tokenizers::Tokenizer`] and [`minijinja::Environment`]
77+ /// providing more utilities.
78+ pub struct Tokenizer < ' a > {
79+ inner : tokenizers:: Tokenizer ,
80+ env : Environment < ' a > ,
81+ }
82+
83+ impl < ' a > Tokenizer < ' a > {
84+ pub fn from_tokenizer ( tokenizer : tokenizers:: Tokenizer ) -> Self {
85+ let mut env = Environment :: new ( ) ;
86+ env. set_unknown_method_callback ( minijinja_contrib:: pycompat:: unknown_method_callback) ;
87+ Self { inner : tokenizer, env }
88+ }
89+
90+ pub fn from_file ( file : impl AsRef < Path > ) -> tokenizers:: Result < Self > {
91+ tokenizers:: Tokenizer :: from_file ( file)
92+ . map ( Self :: from_tokenizer)
93+ }
94+
95+ pub fn from_bytes ( bytes : impl AsRef < [ u8 ] > ) -> tokenizers:: Result < Self > {
96+ tokenizers:: Tokenizer :: from_bytes ( bytes)
97+ . map ( Self :: from_tokenizer)
98+ }
99+
100+ pub fn from_str ( s : & str ) -> tokenizers:: Result < Self > {
101+ tokenizers:: Tokenizer :: from_str ( s)
102+ . map ( Self :: from_tokenizer)
103+ }
104+
105+ pub fn apply_chat_template < R , T > (
106+ & ' a mut self ,
107+ model_template : impl Into < Cow < ' a , str > > ,
108+ args : ApplyChatTemplateArgs < ' a , R , T >
109+ ) -> Result < String , Error >
110+ where
111+ R : Serialize ,
112+ T : Serialize ,
113+ {
114+ apply_chat_template ( & mut self . env , model_template, args)
115+ }
116+
117+ pub fn apply_chat_template_and_encode < R , T > (
118+ & ' a mut self ,
119+ model_template : impl Into < String > ,
120+ args : ApplyChatTemplateArgs < ' a , R , T > ,
121+ tokenize_options : TokenizeOptions ,
122+ ) -> Result < Encoding , Error > {
123+ todo ! ( )
124+ }
125+ }
126+
127+ impl Deref for Tokenizer < ' _ > {
128+ type Target = tokenizers:: Tokenizer ;
129+
130+ fn deref ( & self ) -> & Self :: Target {
131+ & self . inner
132+ }
133+ }
134+
135+ impl DerefMut for Tokenizer < ' _ > {
136+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
137+ & mut self . inner
138+ }
80139}
81140
141+
82142#[ derive( Debug , Serialize ) ]
83143#[ serde( rename_all = "lowercase" ) ]
84144pub enum Role {
85145 User ,
86146 Assistant ,
87147}
88148
149+ #[ derive( Debug , Serialize ) ]
150+ pub enum Content {
151+ String ( String ) ,
152+ Map ( HashMap < String , String > )
153+ }
154+
89155#[ derive( Serialize ) ]
90- pub struct Conversation < R = Role , T = String > {
156+ pub struct Conversation < R , T > {
91157 pub role : R ,
92- pub content : Content < T > ,
158+ pub content : T ,
93159}
94160
95161pub type Documents = HashMap < String , String > ;
@@ -104,8 +170,8 @@ pub enum Truncation {
104170}
105171
106172#[ derive( Default ) ]
107- pub struct ApplyChatTemplateArgs < ' a > {
108- pub conversations : & ' a [ Conversation ] ,
173+ pub struct ApplyChatTemplateArgs < ' a , R = Role , T = String > {
174+ pub conversations : & ' a [ Conversation < R , T > ] ,
109175 pub tools : Option < Box < dyn FnOnce ( ) > > , // TODO: how to get response?
110176 pub documents : Option < & ' a [ Documents ] > ,
111177 pub model_id : & ' a str ,
@@ -306,11 +372,15 @@ pub fn load_model_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Re
306372
307373// return rendered, all_generation_indices
308374
309- pub fn apply_chat_template < ' a > (
375+ pub fn apply_chat_template < ' a , R , T > (
310376 env : & ' a mut Environment < ' a > ,
311- model_template : & ' a str ,
312- args : ApplyChatTemplateArgs < ' a > ,
313- ) -> Result < String , Error > {
377+ model_template : impl Into < Cow < ' a , str > > ,
378+ args : ApplyChatTemplateArgs < ' a , R , T > ,
379+ ) -> Result < String , Error >
380+ where
381+ R : Serialize ,
382+ T : Serialize ,
383+ {
314384 let ApplyChatTemplateArgs {
315385 conversations,
316386 tools,
@@ -329,7 +399,7 @@ pub fn apply_chat_template<'a>(
329399 None => match env. get_template ( model_id) {
330400 Ok ( template) => template,
331401 Err ( _) => {
332- env. add_template_owned ( model_id, model_template. to_owned ( ) ) ?;
402+ env. add_template_owned ( model_id, model_template) ?;
333403 env. get_template ( model_id)
334404 . expect ( "Newly added template must be present" )
335405 }
@@ -392,7 +462,7 @@ mod tests {
392462 let conversations = vec ! [
393463 Conversation {
394464 role: Role :: User ,
395- content: crate :: tokenizer :: Content :: Typed ( "hello" . to_string ( ) )
465+ content: "hello" ,
396466 }
397467 ] ;
398468
@@ -417,4 +487,44 @@ mod tests {
417487 let rendered_chat = apply_chat_template ( & mut env, & model_chat_template, args) . unwrap ( ) ;
418488 println ! ( "{:?}" , rendered_chat) ;
419489 }
490+
491+ #[ test]
492+ fn test_tokenizer_apply_chat_template ( ) {
493+ let hf_cache_dir = PathBuf :: from ( "./hf_cache" ) ;
494+
495+ let api = ApiBuilder :: new ( )
496+ . with_endpoint ( "https://hf-mirror.com" . to_string ( ) ) // comment out this line if your area is not banned
497+ . with_cache_dir ( hf_cache_dir)
498+ . build ( ) . unwrap ( ) ;
499+ let model_id = "mlx-community/Qwen3-4B-bf16" . to_string ( ) ;
500+
501+ let conversations = vec ! [
502+ Conversation {
503+ role: Role :: User ,
504+ content: "hello" ,
505+ }
506+ ] ;
507+
508+ let repo = api. repo ( Repo :: new ( model_id. clone ( ) , hf_hub:: RepoType :: Model ) ) ;
509+ let tokenizer_file = repo. get ( "tokenizer.json" ) . unwrap ( ) ;
510+ let tokenizer_config_file = repo. get ( "tokenizer_config.json" ) . unwrap ( ) ;
511+
512+ let mut tokenizer = super :: Tokenizer :: from_file ( tokenizer_file) . unwrap ( ) ;
513+
514+ let model_chat_template = load_model_chat_template_from_file ( tokenizer_config_file) . unwrap ( ) . unwrap ( ) ;
515+ assert ! ( !model_chat_template. is_empty( ) ) ;
516+
517+ let args = ApplyChatTemplateArgs {
518+ conversations : & conversations,
519+ tools : None ,
520+ documents : None ,
521+ model_id : & model_id,
522+ chat_template_id : None ,
523+ add_generation_prompt : None ,
524+ continue_final_message : None ,
525+ } ;
526+
527+ let rendered_chat = tokenizer. apply_chat_template ( & model_chat_template, args) . unwrap ( ) ;
528+ println ! ( "{:?}" , rendered_chat) ;
529+ }
420530}
0 commit comments