File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -152,11 +152,13 @@ def __init__(
152152 template : str ,
153153 eos_token : str ,
154154 bos_token : str ,
155+ add_generation_prompt : bool = True ,
155156 ):
156157 """A chat formatter that uses jinja2 templates to format the prompt."""
157158 self .template = template
158159 self .eos_token = eos_token
159160 self .bos_token = bos_token
161+ self .add_generation_prompt = add_generation_prompt
160162
161163 self ._environment = jinja2 .Environment (
162164 loader = jinja2 .BaseLoader (),
@@ -170,12 +172,13 @@ def __call__(
170172 messages : List [llama_types .ChatCompletionRequestMessage ],
171173 ** kwargs : Any ,
172174 ) -> ChatFormatterResponse :
173- messages = [
174- * messages ,
175- llama_types .ChatCompletionRequestAssistantMessage (
176- role = "assistant" , content = ""
177- ),
178- ]
175+ if self .add_generation_prompt :
176+ messages = [
177+ * messages ,
178+ llama_types .ChatCompletionRequestAssistantMessage (
179+ role = "assistant" , content = ""
180+ ),
181+ ]
179182 prompt = self ._environment .render (
180183 messages = messages , eos_token = self .eos_token , bos_token = self .bos_token
181184 )
You can’t perform that action at this time.
0 commit comments