@@ -73,13 +73,16 @@ def _map_roles(
 
 
 def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
 ) -> str:
     """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
     ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
         else:
             ret += role + " "
     return ret
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
     )
 
 
+# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
+# system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
 def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
-    _roles = dict(user="[INST]", assistant="[/INST]")
-    _sep = "\n\n"
-    system_message = _get_system_message(messages)
-    system_message = _system_template.format(system_message=system_message)
+    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
-    _messages.append((_roles["assistant"], None))
-    _prompt = _format_llama2(system_message, _messages, _sep)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
 
 
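For reference, a minimal standalone sketch of the new prompt construction. It copies _format_llama2 from the hunk above and hand-builds the (role, message) pairs that _map_roles would produce with the new role tags; the sample conversation and variable names in the driver are illustrative, not part of the commit.

# Standalone sketch of the new llama-2 prompt assembly (not the library module itself).
from typing import List, Optional, Tuple


def _format_llama2(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
    """Format the prompt with the llama2 style (as in the new version above)."""
    seps = [sep, sep2]
    ret = system_message + sep
    for i, (role, message) in enumerate(messages):
        if system_message and i == 0:
            # The system block already carries "<s>[INST]", so the first user
            # message is appended without repeating its role tag.
            ret += message + seps[i % 2]
        elif message:
            ret += role + message + " " + seps[i % 2]
        else:
            ret += role + " "
    return ret


if __name__ == "__main__":
    # What format_llama2 would pass in for a short conversation (assumed example):
    # the system prompt is wrapped in the new template, and user/assistant turns
    # are mapped to "<s>[INST]" / "[/INST]" tags.
    system_message = "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>"
    mapped = [
        ("<s>[INST]", "Hello!"),
        ("[/INST]", "Hi there."),
        ("<s>[INST]", "How are you?"),
    ]
    prompt = _format_llama2(system_message, mapped, " ", "</s>") + "[/INST]"
    # Prints the assembled prompt: each exchange is wrapped as
    # "<s>[INST] ... [/INST] ... </s>", with the system block embedded in the
    # first turn and a trailing "[/INST]" left open for the model's reply.
    print(prompt)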