@@ -121,14 +121,21 @@ def decorator(f: LlamaChatCompletionHandler):
121121
@dataclasses.dataclass
class ChatFormatterResponse:
    """Completion parameters produced by a chat format for a
    create_chat_completion request.

    Attributes:
        prompt: The fully formatted prompt text built from the chat messages.
        stop: Optional stop token, or list of stop tokens, for this format.
    """

    prompt: str
    stop: Optional[Union[str, List[str]]] = None
127133
128134class ChatFormatter (Protocol ):
129135 """Base Protocol for a chat formatter. A chat formatter is a function that
130- takes a list of messages and returns a formatted prompt. It can also return
131- a stop token or list of stop tokens to use for the completion."""
136+ takes a list of messages and returns a chat format response which can be used
137+ to generate a completion. The response can also include a stop token or list
138+ of stop tokens to use for the completion."""
132139
133140 def __call__ (
134141 self ,
@@ -139,131 +146,43 @@ def __call__(
139146 ...
140147
141148
142- ### Utility functions for formatting chat prompts ###
143-
144-
145- def _get_system_message (
146- messages : List [llama_types .ChatCompletionRequestMessage ],
147- ) -> str :
148- """Get the first system message."""
149- for message in messages :
150- if message ["role" ] == "system" :
151- return message ["content" ] or ""
152- return ""
153-
154-
155- def _map_roles (
156- messages : List [llama_types .ChatCompletionRequestMessage ],
157- role_map : Dict [str , str ],
158- ) -> List [Tuple [str , Optional [str ]]]:
159- """Map the message roles."""
160- output : List [Tuple [str , Optional [str ]]] = []
161- for message in messages :
162- role = message ["role" ]
163- if role in role_map :
164- content : str | None = (
165- message ["content" ] if isinstance (message ["content" ], str ) else None
166- )
167- output .append ((role_map [role ], content ))
168- return output
169-
170-
171- def _format_llama2 (
172- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str , sep2 : str
173- ) -> str :
174- """Format the prompt with the llama2 style."""
175- seps = [sep , sep2 ]
176- ret = system_message + sep
177- for i , (role , message ) in enumerate (messages ):
178- if system_message and i == 0 :
179- m = message or ""
180- ret += m + seps [i % 2 ]
181- elif message :
182- ret += role + message + " " + seps [i % 2 ]
183- else :
184- ret += role + " "
185- return ret
186-
187-
188- def _format_add_colon_single (
189- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
190- ) -> str :
191- """Format the prompt with the add-colon-single style."""
192- ret = system_message + sep
193- for role , message in messages :
194- if message :
195- ret += role + ": " + message + sep
196- else :
197- ret += role + ":"
198- return ret
199-
200-
201- def _format_add_colon_two (
202- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str , sep2 : str
203- ) -> str :
204- """Format the prompt with the add-colon-two style."""
205- seps = [sep , sep2 ]
206- ret = system_message + seps [0 ]
207- for i , (role , message ) in enumerate (messages ):
208- if message :
209- ret += role + ": " + message + seps [i % 2 ]
210- else :
211- ret += role + ":"
212- return ret
213-
214-
215- def _format_no_colon_single (
216- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
217- ) -> str :
218- """Format the prompt with the no-colon-single style."""
219- ret = system_message
220- for role , message in messages :
221- if message :
222- ret += role + message + sep
223- else :
224- ret += role
225- return ret
226-
227-
228- def _format_add_colon_space_single (
229- system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
230- ) -> str :
231- """Format the prompt with the add-colon-space-single style."""
232- ret = system_message + sep
233- for role , message in messages :
234- if message :
235- ret += role + ": " + message + sep
236- else :
237- ret += role + ": " # must be end with a space
238- return ret
239-
class Jinja2ChatFormatter(ChatFormatter):
    """Chat formatter driven by a jinja2 template.

    The template receives the message list plus the model's special tokens
    and renders the final prompt string.
    """

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt."""
        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token

        # Compile the template once at construction; each call only renders.
        # NOTE(review): plain (non-sandboxed) jinja2 environment — templates
        # must come from a trusted source.
        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(self.template)

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Append an empty assistant turn so the template emits the assistant
        # header, priming the model to generate the reply.
        primed = [
            *messages,
            llama_types.ChatCompletionRequestAssistantMessage(
                role="assistant", content=""
            ),
        ]
        rendered = self._environment.render(
            messages=primed, eos_token=self.eos_token, bos_token=self.bos_token
        )
        return ChatFormatterResponse(prompt=rendered, stop=[self.eos_token])

    def to_chat_handler(self) -> LlamaChatCompletionHandler:
        return chat_formatter_to_chat_completion_handler(self)
267186
268187
269188def _convert_text_completion_to_chat (
@@ -426,16 +345,6 @@ def chat_completion_handler(
426345 return chat_completion_handler
427346
428347
429- def register_chat_format (name : str ):
430- def decorator (f : ChatFormatter ):
431- chat_completion_handler = chat_formatter_to_chat_completion_handler (f )
432- LlamaChatCompletionHandlerRegistry ().register_chat_completion_handler (
433- name , chat_completion_handler
434- )
435- return f
436- return decorator
437-
438-
439348def hf_autotokenizer_to_chat_formatter (
440349 pretrained_model_name_or_path : Union [str , os .PathLike [str ]]
441350) -> ChatFormatter :
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
466375 return chat_formatter_to_chat_completion_handler (chat_formatter )
467376
468377
469- def hf_tokenizer_config_to_chat_formatter (tokenizer_config : Dict [str , Any ]) -> ChatFormatter :
378+ def hf_tokenizer_config_to_chat_formatter (
379+ tokenizer_config : Dict [str , Any ]
380+ ) -> ChatFormatter :
470381 assert isinstance (tokenizer_config , dict )
471382
472383 assert "chat_template" in tokenizer_config
@@ -504,6 +415,7 @@ def format_autotokenizer(
504415 eos_token = eos_token ,
505416 )
506417 return ChatFormatterResponse (prompt = prompt , stop = eos_token )
418+
507419 return format_autotokenizer
508420
509421
@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
514426 return chat_formatter_to_chat_completion_handler (chat_formatter )
515427
516428
429+ ### Utility functions for formatting chat prompts ###
430+
431+
432+ def _get_system_message (
433+ messages : List [llama_types .ChatCompletionRequestMessage ],
434+ ) -> str :
435+ """Get the first system message."""
436+ for message in messages :
437+ if message ["role" ] == "system" :
438+ return message ["content" ] or ""
439+ return ""
440+
441+
442+ def _map_roles (
443+ messages : List [llama_types .ChatCompletionRequestMessage ],
444+ role_map : Dict [str , str ],
445+ ) -> List [Tuple [str , Optional [str ]]]:
446+ """Map the message roles."""
447+ output : List [Tuple [str , Optional [str ]]] = []
448+ for message in messages :
449+ role = message ["role" ]
450+ if role in role_map :
451+ content : str | None = (
452+ message ["content" ] if isinstance (message ["content" ], str ) else None
453+ )
454+ output .append ((role_map [role ], content ))
455+ return output
456+
457+
458+ def _format_llama2 (
459+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str , sep2 : str
460+ ) -> str :
461+ """Format the prompt with the llama2 style."""
462+ seps = [sep , sep2 ]
463+ ret = system_message + sep
464+ for i , (role , message ) in enumerate (messages ):
465+ if system_message and i == 0 :
466+ m = message or ""
467+ ret += m + seps [i % 2 ]
468+ elif message :
469+ ret += role + message + " " + seps [i % 2 ]
470+ else :
471+ ret += role + " "
472+ return ret
473+
474+
475+ def _format_add_colon_single (
476+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
477+ ) -> str :
478+ """Format the prompt with the add-colon-single style."""
479+ ret = system_message + sep
480+ for role , message in messages :
481+ if message :
482+ ret += role + ": " + message + sep
483+ else :
484+ ret += role + ":"
485+ return ret
486+
487+
488+ def _format_add_colon_two (
489+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str , sep2 : str
490+ ) -> str :
491+ """Format the prompt with the add-colon-two style."""
492+ seps = [sep , sep2 ]
493+ ret = system_message + seps [0 ]
494+ for i , (role , message ) in enumerate (messages ):
495+ if message :
496+ ret += role + ": " + message + seps [i % 2 ]
497+ else :
498+ ret += role + ":"
499+ return ret
500+
501+
502+ def _format_no_colon_single (
503+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
504+ ) -> str :
505+ """Format the prompt with the no-colon-single style."""
506+ ret = system_message
507+ for role , message in messages :
508+ if message :
509+ ret += role + message + sep
510+ else :
511+ ret += role
512+ return ret
513+
514+
515+ def _format_add_colon_space_single (
516+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
517+ ) -> str :
518+ """Format the prompt with the add-colon-space-single style."""
519+ ret = system_message + sep
520+ for role , message in messages :
521+ if message :
522+ ret += role + ": " + message + sep
523+ else :
524+ ret += role + ": " # must be end with a space
525+ return ret
526+
527+
528+ def _format_chatml (
529+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
530+ ) -> str :
531+ """Format the prompt with the chatml style."""
532+ ret = "" if system_message == "" else system_message + sep + "\n "
533+ for role , message in messages :
534+ if message :
535+ ret += role + "\n " + message + sep + "\n "
536+ else :
537+ ret += role + "\n "
538+ return ret
539+
540+
541+ def _format_chatglm3 (
542+ system_message : str , messages : List [Tuple [str , Optional [str ]]], sep : str
543+ ) -> str :
544+ """Format the prompt with the chatglm3 style."""
545+ ret = ""
546+ if system_message :
547+ ret += system_message
548+ for role , message in messages :
549+ if message :
550+ ret += role + "\n " + " " + message
551+ else :
552+ ret += role
553+ return ret
554+
555+
556+ ### Chat Formats ###
557+
558+
def register_chat_format(name: str):
    """Decorator factory: register a ChatFormatter under ``name``.

    Wraps the decorated formatter in a chat-completion handler and records
    it in the global LlamaChatCompletionHandlerRegistry, then returns the
    formatter unchanged so it remains usable directly.
    """

    def decorator(f: ChatFormatter):
        handler = chat_formatter_to_chat_completion_handler(f)
        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
            name, handler
        )
        return f

    return decorator
568+
569+
517570# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
518571# system prompt is "embedded" in the first message
519572@register_chat_format ("llama-2" )
0 commit comments