1010
1111import jinja2
1212
13+ import numpy as np
14+ import numpy .typing as npt
15+
1316import llama_cpp .llama as llama
1417import llama_cpp .llama_types as llama_types
1518import llama_cpp .llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
150153
151154 prompt : str
152155 stop : Optional [Union [str , List [str ]]] = None
156+ stopping_criteria : Optional [llama .StoppingCriteriaList ] = None
153157
154158
155159class ChatFormatter (Protocol ):
@@ -173,12 +177,14 @@ def __init__(
173177 eos_token : str ,
174178 bos_token : str ,
175179 add_generation_prompt : bool = True ,
180+ stop_token_ids : Optional [List [int ]] = None ,
176181 ):
177182 """A chat formatter that uses jinja2 templates to format the prompt."""
178183 self .template = template
179184 self .eos_token = eos_token
180185 self .bos_token = bos_token
181186 self .add_generation_prompt = add_generation_prompt
187+ self .stop_token_ids = set (stop_token_ids ) if stop_token_ids is not None else None
182188
183189 self ._environment = jinja2 .Environment (
184190 loader = jinja2 .BaseLoader (),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
211217 tool_choice = tool_choice ,
212218 )
213219
214- return ChatFormatterResponse (prompt = prompt , stop = [self .eos_token ])
220+ stopping_criteria = None
221+ if self .stop_token_ids is not None :
def stop_on_last_token(
    tokens: npt.NDArray[np.intc],
    logits: npt.NDArray[np.single],
) -> bool:
    """Report whether the final entry of ``tokens`` is a configured stop id.

    ``logits`` is accepted but never consulted. The stop ids are looked up
    on ``self.stop_token_ids`` of the enclosing formatter on every call.
    """
    newest_token = tokens[-1]
    return newest_token in self.stop_token_ids
227+ stopping_criteria = llama .StoppingCriteriaList ([stop_on_last_token ])
228+
229+ return ChatFormatterResponse (prompt = prompt , stop = [self .eos_token ], stopping_criteria = stopping_criteria )
215230
def to_chat_handler(self) -> LlamaChatCompletionHandler:
    """Return this formatter wrapped as a chat completion handler.

    Delegates to ``chat_formatter_to_chat_completion_handler``, passing this
    formatter instance as the only argument.
    """
    handler = chat_formatter_to_chat_completion_handler(self)
    return handler
@@ -533,6 +548,10 @@ def chat_completion_handler(
533548 rstop = result .stop if isinstance (result .stop , list ) else [result .stop ]
534549 stop = stop + rstop
535550
551+ stopping_criteria = None
552+ if result .stopping_criteria is not None :
553+ stopping_criteria = result .stopping_criteria
554+
536555 if response_format is not None and response_format ["type" ] == "json_object" :
537556 grammar = _grammar_for_response_format (response_format , verbose = llama .verbose )
538557
@@ -598,6 +617,7 @@ def chat_completion_handler(
598617 mirostat_eta = mirostat_eta ,
599618 model = model ,
600619 logits_processor = logits_processor ,
620+ stopping_criteria = stopping_criteria ,
601621 grammar = grammar ,
602622 logit_bias = logit_bias ,
603623 )
0 commit comments