@@ -42,6 +42,7 @@ def __init__(
4242 * ,
4343 tokenizer_config_json_path : str | None = None ,
4444 generation_config_path : str | None = None ,
45+ chat_template_path : str | None = None ,
4546 ):
4647 self .tokenizer = Tokenizer .from_file (tokenizer_json_path )
4748 if not (tokenizer_config_json_path or generation_config_path ):
@@ -51,6 +52,10 @@ def __init__(
5152 if tokenizer_config_json_path :
5253 with open (tokenizer_config_json_path , "rb" ) as f :
5354 self .config = json .load (f )
55+ if chat_template_path :
56+ with open (chat_template_path , "r" ) as f :
57+ # TODO: warn when this overwrites a "chat_template" already present in the loaded tokenizer config?
58+ self .config ["chat_template" ] = f .read ()
5459 else :
5560 self .config = None
5661 if generation_config_path :
@@ -227,12 +232,14 @@ def __init__(
227232 * ,
228233 tokenizer_config_json_path : str | None = None ,
229234 generation_config_path : str | None = None ,
235+ chat_template_path : str | None = None ,
230236 truncation_type : str = "right" ,
231237 ):
232238 self .base_tokenizer = HuggingFaceBaseTokenizer (
233239 tokenizer_json_path = tokenizer_json_path ,
234240 tokenizer_config_json_path = tokenizer_config_json_path ,
235241 generation_config_path = generation_config_path ,
242+ chat_template_path = chat_template_path
236243 )
237244
238245 # Contents of the tokenizer_config.json
0 commit comments