@@ -37,7 +37,8 @@ def __contains__(self, key: str):
3737
# Feels weird having to instantiate this, but it's a singleton for all purposes
# TODO [HN]: Add an alias system so we can instantiate with other simple keys (e.g. "llama2" instead of the full template string)
# Module-level registry mapping raw chat-template strings -> ChatTemplate subclasses;
# populated by the CHAT_TEMPLATE_CACHE[...] = ... assignments below and consulted by load_template_class().
CHAT_TEMPLATE_CACHE = ChatTemplateCache()
4142
4243class UnsupportedRoleException (Exception ):
4344 def __init__ (self , role_name , instance ):
@@ -46,11 +47,12 @@ def __init__(self, role_name, instance):
4647 super ().__init__ (self ._format_message ())
4748
4849 def _format_message (self ):
49- return (f"Role { self .role_name } is not supported by the { self .instance .__class__ .__name__ } chat template. " )
50+ return f"Role { self .role_name } is not supported by the { self .instance .__class__ .__name__ } chat template. "
51+
5052
def load_template_class(chat_template=None):
    """Resolve the most appropriate chat template class.

    Order of precedence:
    - If it's a chat template class, use it directly
    - If it's a string, check the cache of popular model templates
    - Otherwise fall back to ChatMLTemplate, warning only when the caller
      actually supplied something we could not load
    """
    is_template_subclass = inspect.isclass(chat_template) and issubclass(chat_template, ChatTemplate)
    if is_template_subclass:
        # The abstract base itself is not usable as a template.
        if chat_template is ChatTemplate:
            raise Exception(
                "You can't use the base ChatTemplate class directly. Create or use a subclass instead."
            )
        return chat_template

    if isinstance(chat_template, str):
        # Known template string? Serve the cached class.
        # TODO: Expand keys of cache to include aliases for popular model types (e.g. "llama2, phi3")
        # Can possibly accomplish this with an "aliases" dictionary that maps all aliases to the canonical key in cache
        if chat_template in CHAT_TEMPLATE_CACHE:
            return CHAT_TEMPLATE_CACHE[chat_template]
        # TODO: Add logic here to try to auto-create class dynamically via _template_class_from_string method

    # Only warn when a user provided a chat template that we couldn't load
    if chat_template is not None:
        warnings.warn(
            f"""Chat template {chat_template} was unable to be loaded directly into guidance.
                    Defaulting to the ChatML format which may not be optimal for the selected model.
                    For best results, create and pass in a `guidance.ChatTemplate` subclass for your model."""
        )

    # By default, use the ChatML Template. Warnings to user will happen downstream only if they use chat roles.
    return ChatMLTemplate
8288
@@ -94,15 +100,18 @@ def _template_class_from_string(template_str):
# --------------------------------------------------
# Note that all grammarless models will default to this syntax, since we typically send chat formatted messages.
chatml_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"


class ChatMLTemplate(ChatTemplate):
    """Default template using the ChatML wire format (<|im_start|>/<|im_end|> markers)."""

    template_str = chatml_template

    def get_role_start(self, role_name):
        """Opening marker for a message of the given role."""
        return "<|im_start|>" + role_name + "\n"

    def get_role_end(self, role_name=None):
        """Closing marker; identical for every role."""
        return "<|im_end|>\n"


CHAT_TEMPLATE_CACHE[chatml_template] = ChatMLTemplate
107116
108117
@@ -111,6 +120,8 @@ def get_role_end(self, role_name=None):
111120# --------------------------------------------------
112121# [05/08/24] https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/tokenizer_config.json#L12
113122llama2_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\ n' + system_message + '\\ n<</SYS>>\\ n\\ n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
123+
124+
114125class Llama2ChatTemplate (ChatTemplate ):
115126 # available_roles = ["system", "user", "assistant"]
116127 template_str = llama2_template
@@ -124,7 +135,7 @@ def get_role_start(self, role_name):
124135 return " "
125136 else :
126137 raise UnsupportedRoleException (role_name , self )
127-
138+
128139 def get_role_end (self , role_name = None ):
129140 if role_name == "system" :
130141 return "\n <</SYS>"
@@ -135,6 +146,7 @@ def get_role_end(self, role_name=None):
135146 else :
136147 raise UnsupportedRoleException (role_name , self )
137148
149+
138150CHAT_TEMPLATE_CACHE [llama2_template ] = Llama2ChatTemplate
139151
140152
@@ -143,6 +155,8 @@ def get_role_end(self, role_name=None):
143155# --------------------------------------------------
144156# [05/08/24] https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json#L2053
145157llama3_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n \n '+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n \n ' }}{% endif %}"
158+
159+
146160class Llama3ChatTemplate (ChatTemplate ):
147161 # available_roles = ["system", "user", "assistant"]
148162 template_str = llama3_template
@@ -156,52 +170,89 @@ def get_role_start(self, role_name):
156170 return "<|start_header_id|>assistant<|end_header_id|>\n \n "
157171 else :
158172 raise UnsupportedRoleException (role_name , self )
159-
173+
    def get_role_end(self, role_name=None):
        # Llama-3 terminates every turn with the same end-of-turn token, so the
        # role is irrelevant here (parameter kept for interface symmetry).
        return "<|eot_id|>"
162176
177+
163178CHAT_TEMPLATE_CACHE [llama3_template ] = Llama3ChatTemplate
164179
# --------------------------------------------------
# @@@@ Phi-3 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json#L119
phi3_mini_template = "{% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"


class Phi3MiniChatTemplate(ChatTemplate):
    """Chat template for Phi-3-mini; supports the system, user and assistant roles."""

    template_str = phi3_mini_template

    # Opening tag emitted for each supported role.
    _ROLE_TAGS = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }

    def get_role_start(self, role_name):
        """Opening tag for *role_name*; raises for unsupported roles."""
        tag = self._ROLE_TAGS.get(role_name)
        if tag is None:
            raise UnsupportedRoleException(role_name, self)
        return tag

    def get_role_end(self, role_name=None):
        # Every role ends with the same terminator token.
        return "<|end|>"


CHAT_TEMPLATE_CACHE[phi3_mini_template] = Phi3MiniChatTemplate
206+
# https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/main/tokenizer_config.json
phi3_small_template = "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"


# https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/blob/main/tokenizer_config.json#L119
phi3_medium_template = "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"


# Although the templates are different, the roles are the same between medium and small (for now)
class Phi3SmallMediumChatTemplate(ChatTemplate):
    """Shared template class for Phi-3-small and Phi-3-medium (user/assistant roles only)."""

    template_str = phi3_small_template

    def get_role_start(self, role_name):
        """Opening tag for *role_name*; raises for roles other than user/assistant."""
        tag = {"user": "<|user|>", "assistant": "<|assistant|>"}.get(role_name)
        if tag is None:
            raise UnsupportedRoleException(role_name, self)
        return tag

    def get_role_end(self, role_name=None):
        # Both roles share the same terminator token.
        return "<|end|>"


CHAT_TEMPLATE_CACHE[phi3_small_template] = Phi3SmallMediumChatTemplate
CHAT_TEMPLATE_CACHE[phi3_medium_template] = Phi3SmallMediumChatTemplate
187234
188235# --------------------------------------------------
189236# @@@@ Mistral-7B-Instruct-v0.2 @@@@
190237# --------------------------------------------------
191238# [05/08/24] https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json#L42
192- mistral_7b_instruct_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
239+ mistral_7b_instruct_template = "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n {%- else %}\n {%- set loop_messages = messages %}\n {%- endif %}\n \n {{- bos_token }}\n {%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\ n\\ n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endfor %}\n "
240+
241+
193242class Mistral7BInstructChatTemplate (ChatTemplate ):
194243 # available_roles = ["user", "assistant"]
195244 template_str = mistral_7b_instruct_template
196245
197246 def get_role_start (self , role_name ):
198247 if role_name == "user" :
199- return "[INST] "
248+ return " [INST] "
200249 elif role_name == "assistant" :
201250 return " "
251+ elif role_name == "system" :
252+ raise ValueError ("Please include system instructions in the first user message" )
202253 else :
203254 raise UnsupportedRoleException (role_name , self )
204-
255+
205256 def get_role_end (self , role_name = None ):
206257 if role_name == "user" :
207258 return " [/INST]"
@@ -210,4 +261,5 @@ def get_role_end(self, role_name=None):
210261 else :
211262 raise UnsupportedRoleException (role_name , self )
212263
264+
213265CHAT_TEMPLATE_CACHE [mistral_7b_instruct_template ] = Mistral7BInstructChatTemplate
0 commit comments