6565// set, will return a dict of tokenizer outputs instead.
6666// """
6767
68- use std:: { borrow:: Cow , collections:: HashMap } ;
68+ use std:: { borrow:: Cow , collections:: HashMap , fs :: read_to_string , path :: Path } ;
6969
7070use minijinja:: { context, Environment , Template } ;
7171use serde:: { Deserialize , Serialize } ;
@@ -74,7 +74,7 @@ use crate::error::Error;
7474
7575#[ derive( Serialize ) ]
7676#[ serde( untagged) ]
77- pub enum Content < T : Serialize = ( ) > {
77+ pub enum Content < T : Serialize = ( ) > {
7878 String ( String ) ,
7979 Map ( HashMap < String , String > ) ,
8080 Typed ( T ) ,
@@ -115,6 +115,21 @@ pub struct TokenizeOptions {
115115 pub return_assistant_tokens_mask : Option < bool > ,
116116}
117117
118+ pub fn load_chat_template_from_str ( content : & str ) -> std:: io:: Result < Option < String > > {
119+ serde_json:: from_str :: < serde_json:: Value > ( content) . map ( |value| {
120+ value
121+ . get ( "chat_template" )
122+ . and_then ( |value| value. as_str ( ) )
123+ . map ( ToString :: to_string)
124+ } )
125+ . map_err ( Into :: into)
126+ }
127+
128+ pub fn load_chat_template_from_file ( file : impl AsRef < Path > ) -> std:: io:: Result < Option < String > > {
129+ let content = read_to_string ( file) ?;
130+ load_chat_template_from_str ( & content)
131+ }
132+
118133// chat_template = self.get_chat_template(chat_template, tools)
119134
120135// if isinstance(conversation, (list, tuple)) and (
@@ -192,7 +207,6 @@ pub struct TokenizeOptions {
192207// else:
193208// return rendered_chat
194209
195-
196210// def render_jinja_template(
197211// conversations: list[list[dict[str, str]]],
198212// tools: Optional[list[Union[dict, Callable]]] = None,
@@ -286,7 +300,6 @@ pub struct TokenizeOptions {
286300
287301// return rendered, all_generation_indices
288302
289-
290303pub fn apply_chat_template < ' a > (
291304 env : & ' a mut Environment < ' a > ,
292305 model_template : & ' a str ,
@@ -311,19 +324,20 @@ pub fn apply_chat_template<'a>(
311324 Ok ( template) => template,
312325 Err ( _) => {
313326 env. add_template ( model_id, model_template) ?;
314- env. get_template ( model_id) . expect ( "Newly added template must be present" )
315- } ,
327+ env. get_template ( model_id)
328+ . expect ( "Newly added template must be present" )
329+ }
316330 } ,
317331 } ;
318332
319333 // TODO: what about list of list of conversations
320-
334+
321335 // TODO: handle tool
322336
323337 // TODO: handle documents``
324338
325339 // TODO: allow return_generation_indices
326-
340+
327341 let rendered_chat = template. render ( context ! {
328342 messages => conversations,
329343 documents => documents,
0 commit comments