@@ -378,6 +378,7 @@ def __init__(
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}
 
         self.draft_model = draft_model
 
@@ -409,10 +410,33 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
@@ -423,30 +447,12 @@ def __init__(
             if self.verbose:
                 print(f"Guessed chat format: {chat_format}", file=sys.stderr)
             else:
-                template = self.metadata["tokenizer.chat_template"]
-                try:
-                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-                except:
-                    eos_token_id = self.token_eos()
-                try:
-                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-                except:
-                    bos_token_id = self.token_bos()
-
-                eos_token = self._model.token_get_text(eos_token_id)
-                bos_token = self._model.token_get_text(bos_token_id)
-
                 if self.verbose:
-                    print(f"Using gguf chat template: {template}", file=sys.stderr)
+                    print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
-                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token,
-                    stop_token_ids=[eos_token_id],
-                ).to_chat_handler()
+                self.chat_format = "chat_template.default"
 
         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
@@ -1719,7 +1725,7 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
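
With this change, every chat template shipped in the GGUF metadata is pre-compiled into `self._chat_handlers`, keyed as `chat_template.<name>` (the `name[10:]` slice strips the leading `tokenizer.` prefix), and the bare `tokenizer.chat_template` key is registered as `chat_template.default`. A minimal usage sketch against the public API (the model path here is hypothetical):

```python
from llama_cpp import Llama

# Loading with verbose=True now also prints the chat formats found in
# metadata, e.g. "Available chat formats from metadata: chat_template.default".
llm = Llama(model_path="./models/example-7b-instruct.Q4_K_M.gguf", verbose=True)

# With no explicit chat_format/chat_handler, a model that embeds
# "tokenizer.chat_template" falls through to chat_format="chat_template.default",
# so the GGUF's own Jinja2 template drives the completion.
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(out["choices"][0]["message"]["content"])
```

Handler resolution in `create_chat_completion` is: an explicit `chat_handler`, then a metadata-derived handler from `_chat_handlers`, then the handler registered for `chat_format`, so existing callers are unaffected.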