Skip to content

Commit ff429f4

Browse files
committed
feat: add the chat template factory
1 parent 017ac4a commit ff429f4

File tree

3 files changed

+530
-26
lines changed

3 files changed

+530
-26
lines changed

app/processors/prompt_factory.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
class PromptFactory:
    """Registry of named Jinja chat templates.

    Each private class attribute holds one Jinja template string. The
    templates read ``messages`` (a sequence of mappings with ``role`` and
    ``content`` keys) and ``add_generation_prompt``; several also use
    ``bos_token``, ``eos_token`` and ``raise_exception``, which the rendering
    environment is expected to supply.

    Use :meth:`create_chat_template` to look a template up by name.
    """

    # Alpaca style: "### Instruction:" / "### Response:" sections; a leading
    # system message is merged into the first user turn.
    _ALPACA = (
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'].strip() + '\n' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{% set content = system_message + message['content'] %}"
        "{% else %}"
        "{% set content = message['content'] %}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ '### Instruction:\n' + content.strip() + '\n\n'}}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '### Response:\n' + content.strip() + '\n\n' }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '### Response:\n' }}"
        "{% endif %}"
    )

    # ChatML style: every turn is wrapped in <|im_start|>role ... <|im_end|>.
    _CHAT_ML = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{'<|im_start|>assistant\n'}}"
        "{% endif %}"
    )

    # Default: <|user|> / <|system|> / <|assistant|> tags, each turn closed
    # with eos_token.
    _DEFAULT = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{'<|user|>\n' + message['content'] + eos_token}}"
        "{% elif message['role'] == 'system' %}"
        "{{'<|system|>\n' + message['content'] + eos_token}}"
        "{% elif message['role'] == 'assistant' %}"
        "{{'<|assistant|>\n' + message['content'] + eos_token}}"
        "{% endif %}"
        "{% if loop.last and add_generation_prompt %}"
        "{{'<|assistant|>'}}"
        "{% endif %}"
        "{% endfor %}"
    )

    # Falcon style: "User: ..." / "Assistant: ..." lines with blank lines
    # inside a message collapsed to single newlines.
    _FALCON = (
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'] %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{{ system_message.strip() }}"
        "{% endif %}"
        "{{ '\n\n' + message['role'].title() + ': ' + message['content'].strip().replace('\r\n', '\n').replace('\n\n', '\n') }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        # Must open with "{{"; with a single brace Jinja emits the expression
        # source as literal text instead of evaluating it.
        "{{ '\n\nAssistant:' }}"
        "{% endif %}"
    )

    # Gemma style: <start_of_turn>role ... <end_of_turn>, with the assistant
    # role renamed to "model".
    _GEMMA = (
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{% set content = system_message + message['content'] %}"
        "{% else %}"
        "{% set content = message['content'] %}"
        "{% endif %}"
        "{% if (message['role'] == 'assistant') %}"
        "{% set role = 'model' %}"
        "{% else %}"
        "{% set role = message['role'] %}"
        "{% endif %}"
        "{{ '<start_of_turn>' + role + '\n' + content.strip() + '<end_of_turn>\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{'<start_of_turn>model\n'}}"
        "{% endif %}"
    )

    # Llama 2 style: [INST] ... [/INST] with the system message wrapped in
    # <<SYS>> markers inside the first user turn.
    _LLAMA_2 = (
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = '<<SYS>>\n' + messages[0]['content'].strip() + '\n<</SYS>>\n\n' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{% set content = system_message + message['content'] %}"
        "{% else %}"
        "{% set content = message['content'] %}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ ' ' + content.strip() + ' ' + eos_token }}"
        "{% endif %}"
        "{% endfor %}"
    )

    # Llama 3 style: <|start_header_id|>role<|end_header_id|> headers with
    # each turn closed by <|eot_id|>.
    _LLAMA_3 = (
        "{{ bos_token }}"
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + messages[0]['content'].strip() + '<|eot_id|>' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{{ system_message }}"
        "{% endif %}"
        "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'].strip() + '<|eot_id|>' }}"
        "{% if loop.last and message['role'] == 'user' and add_generation_prompt %}"
        "{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}"
        "{% endif %}"
        "{% endfor %}"
    )

    # Mistral style: [INST] ... [/INST] with the system message prepended to
    # the first user turn.
    _MISTRAL = (
        "{{ bos_token }}"
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{% set content = system_message + message['content'] %}"
        "{% else %}"
        "{% set content = message['content'] %}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ '[INST] ' + content.strip() + ' [/INST]' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ content.strip() + eos_token}}"
        "{% endif %}"
        "{% endfor %}"
    )

    # Phi-2 style: "Instruct: ..." / "Output: ..." lines.
    _PHI_2 = (
        "{% if messages[0]['role'] == 'system' %}"
        "{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'].strip() + '\n\n' %}"
        "{% else %}"
        "{% set loop_messages = messages %}"
        "{% set system_message = '' %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if loop.index0 == 0 %}"
        "{% set content = system_message + message['content'] %}"
        "{% else %}"
        "{% set content = message['content'] %}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ 'Instruct: ' + content.strip() + '\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ 'Output: ' + content.strip() + '\n' }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ 'Output:' }}"
        "{% endif %}"
    )

    # Phi-3 style: <|system|> / <|user|> tags closed with <|end|>; the
    # <|assistant|> tag is emitted right after every user turn.
    _PHI_3 = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{% if (message['role'] == 'system') %}"
        "{{'<|system|>' + '\n' + message['content'].strip() + '<|end|>' + '\n'}}"
        "{% elif (message['role'] == 'user') %}"
        "{{'<|user|>' + '\n' + message['content'].strip() + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
        "{% elif message['role'] == 'assistant' %}"
        "{{message['content'].strip() + '<|end|>' + '\n'}}"
        "{% endif %}"
        "{% endfor %}"
    )

    # Qwen style: ChatML markers plus a stock system message when the
    # conversation does not start with one.
    _QWEN = (
        "{% for message in messages %}"
        "{% if loop.first and messages[0]['role'] != 'system' %}"
        "{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}"
        "{% endif %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'].strip() }}"
        "{% if (loop.last and add_generation_prompt) or not loop.last %}"
        "{{ '<|im_end|>' + '\n'}}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )

    # Lower-cased public template name -> template string.
    _TEMPLATES = {
        "default": _DEFAULT,
        "alpaca": _ALPACA,
        "chat_ml": _CHAT_ML,
        "falcon": _FALCON,
        "gemma": _GEMMA,
        "llama_2": _LLAMA_2,
        "llama_3": _LLAMA_3,
        "mistral": _MISTRAL,
        "phi_2": _PHI_2,
        "phi_3": _PHI_3,
        "qwen": _QWEN,
    }

    @classmethod
    def create_chat_template(cls, name: str = "default") -> str:
        """Return the Jinja chat template registered under ``name``.

        Args:
            name (str): Template name, matched case-insensitively. One of
                "default", "alpaca", "chat_ml", "falcon", "gemma", "llama_2",
                "llama_3", "mistral", "phi_2", "phi_3" or "qwen".

        Returns:
            str: The Jinja template source.

        Raises:
            ValueError: If ``name`` does not match any registered template.
        """
        template = cls._TEMPLATES.get(name.lower())
        if template is None:
            raise ValueError("Invalid template name")
        return template

app/utils.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from app.config import Settings
2828
from app.domain import Annotation, Entity, CodeType, ModelType, Device, PromptMessage, PromptRole
2929
from app.exception import ManagedModelException
30+
from app.processors.prompt_factory import PromptFactory
3031

3132

3233
@lru_cache
@@ -739,47 +740,60 @@ def download_model_package(
739740
retry_delay *= 2
740741

741742

742-
def get_prompt_from_messages(
    tokenizer: PreTrainedTokenizer,
    messages: List[PromptMessage],
    override_template: Optional[str] = None,
) -> str:
    """
    Generates a prompt from a list of prompt messages.

    Template selection, in order of precedence:
      1. ``override_template`` when given (looked up via ``PromptFactory``);
      2. the tokenizer's own ``chat_template``;
      3. the tokenizer's ``default_chat_template`` (older HF tokenizers);
      4. a hand-built ``<|system|>`` / ``<|user|>`` / ``<|assistant|>`` layout.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for applying the chat template.
        messages (List[PromptMessage]): The list of prompt messages to use for generating the prompt.
        override_template (Optional[str]): The name of the chat template to use for generating the
            prompt. When None, the template is resolved from the tokenizer.

    Returns:
        str: The generated prompt.

    Raises:
        ValueError: If ``override_template`` is not a template name known to ``PromptFactory``.
    """
    # Extra keyword arguments for apply_chat_template; None means "no template available".
    template_kwargs: Optional[dict] = None
    if override_template is not None:
        # Pass the template per call instead of assigning tokenizer.chat_template, so the
        # override does not leak into later calls that share this tokenizer.
        template_kwargs = {"chat_template": PromptFactory.create_chat_template(name=override_template)}
    elif hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        template_kwargs = {}
    elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template:
        # This largely depends on how older versions of HF tokenizers behave and may not work universally
        tokenizer.chat_template = tokenizer.default_chat_template
        template_kwargs = {}

    if template_kwargs is not None:
        return tokenizer.apply_chat_template(
            [dump_pydantic_object_to_dict(message) for message in messages],
            tokenize=False,
            add_generation_prompt=True,
            **template_kwargs,
        )

    # No template available: build the prompt by hand.
    system_content = ""
    prompt_parts: List[str] = []
    for message in messages:
        content = message.content.strip()
        if message.role == PromptRole.SYSTEM:
            # Only the last system message is kept.
            system_content = content
        elif message.role == PromptRole.USER:
            prompt_parts.append(f"<|user|>\n{content}</s>")
        elif message.role == PromptRole.ASSISTANT:
            prompt_parts.append(f"<|assistant|>\n{content}</s>")
    if system_content:
        prompt = f"<|system|>\n{system_content}</s>\n" + "\n".join(prompt_parts)
    else:
        prompt = "\n".join(prompt_parts)
    prompt += "\n<|assistant|>\n"
    return prompt
784798

785799
TYPE_ID_TO_NAME_PATCH = {

0 commit comments

Comments
 (0)