
Commit d6cb286 (parent: 22e744b)

[Test] Basic checks on chat templates (#944)

File tree: 2 files changed (+82, -10 lines)


guidance/chat.py

Lines changed: 7 additions & 7 deletions

@@ -190,16 +190,16 @@ class Phi3MiniChatTemplate(ChatTemplate):
 
     def get_role_start(self, role_name):
         if role_name == "user":
-            return "<|user|>"
+            return "<|user|>\n"
         elif role_name == "assistant":
-            return "<|assistant|>"
+            return "<|assistant|>\n"
         elif role_name == "system":
-            return "<|system|>"
+            return "<|system|>\n"
         else:
             raise UnsupportedRoleException(role_name, self)
 
     def get_role_end(self, role_name=None):
-        return "<|end|>"
+        return "<|end|>\n"
 
 
 CHAT_TEMPLATE_CACHE[phi3_mini_template] = Phi3MiniChatTemplate
@@ -219,14 +219,14 @@ class Phi3SmallMediumChatTemplate(ChatTemplate):
 
     def get_role_start(self, role_name):
         if role_name == "user":
-            return "<|user|>"
+            return "<|user|>\n"
         elif role_name == "assistant":
-            return "<|assistant|>"
+            return "<|assistant|>\n"
         else:
             raise UnsupportedRoleException(role_name, self)
 
     def get_role_end(self, role_name=None):
-        return "<|end|>"
+        return "<|end|>\n"
 
 
 CHAT_TEMPLATE_CACHE[phi3_small_template] = Phi3SmallMediumChatTemplate
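
The net effect of this file's change is that every Phi-3 role header is now followed by a newline, matching what the upstream tokenizer templates emit. A minimal sketch of the resulting render, assuming Phi3MiniChatTemplate can be imported from guidance.chat and constructed with no arguments (as the cache lookup in the tests below does):

from guidance.chat import Phi3MiniChatTemplate

# Stitch a two-turn conversation together from the role markers.
tmpl = Phi3MiniChatTemplate()
rendered = (
    tmpl.get_role_start("user") + "Good day to you!" + tmpl.get_role_end("user")
    + tmpl.get_role_start("assistant") + "Hello!" + tmpl.get_role_end("assistant")
)
# With the trailing newlines from this commit, `rendered` should be:
# '<|user|>\nGood day to you!<|end|>\n<|assistant|>\nHello!<|end|>\n'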

tests/need_credentials/test_chat_templates.py

Lines changed: 75 additions & 3 deletions

@@ -1,7 +1,9 @@
 import pytest
+import transformers
+
+import guidance
 
 from guidance.chat import CHAT_TEMPLATE_CACHE
-import transformers
 
 from ..utils import env_or_fail
 
@@ -23,8 +25,6 @@ def test_popular_models_in_cache(model_id: str, should_pass: bool):
     # If this fails, the models have had their templates updated, and we need to fix the cache manually.
     hf_token = env_or_fail("HF_TOKEN")
 
-    # model_id, should_pass = model_info
-
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         model_id, token=hf_token, trust_remote_code=True
    )
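
For reference, env_or_fail is a small helper from the tests' utils module. Its body is not shown in this diff; a plausible sketch, assuming it reads an environment variable and fails the run when the variable is missing (the name comes from the test code, the implementation below is an assumption):

import os
import pytest

def env_or_fail(var_name: str) -> str:
    # Return the environment variable's value, or fail loudly so a
    # credential-requiring test never runs with a missing HF_TOKEN.
    value = os.getenv(var_name)
    if value is None:
        pytest.fail(f"Environment variable {var_name} is not set")
    return value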
@@ -38,3 +38,75 @@ def test_popular_models_in_cache(model_id: str, should_pass: bool):
 
 # TODO: Expand testing to verify that tokenizer.apply_chat_template() produces same results as our ChatTemplate subclasses
 # once I hook up the new ChatTemplate to guidance.models.Transformers and guidance.models.LlamaCPP, we can do this
+
+
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "microsoft/Phi-3-mini-4k-instruct",
+        "microsoft/Phi-3-small-8k-instruct",
+        "microsoft/Phi-3-medium-4k-instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Llama-2-7b-chat-hf",
+        "mistralai/Mistral-7B-Instruct-v0.2",
+    ],
+)
+def test_chat_format_smoke(model_id: str):
+    hf_token = env_or_fail("HF_TOKEN")
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_id, token=hf_token, trust_remote_code=True
+    )
+    model_chat_template = tokenizer.chat_template
+
+    lm = guidance.models.Mock("")
+    lm.chat_template = CHAT_TEMPLATE_CACHE[model_chat_template]()
+
+    messages = [
+        {"role": "user", "content": "Good day to you!"},
+        {"role": "assistant", "content": "Hello!"},
+    ]
+    tokeniser_render = tokenizer.apply_chat_template(messages, tokenize=False)
+
+    with guidance.user():
+        lm += "Good day to you!"
+    with guidance.assistant():
+        lm += "Hello!"
+    # Only check substring due to BOS/EOS tokens
+    assert str(lm) in tokeniser_render
+
+
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "microsoft/Phi-3-mini-4k-instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Llama-2-7b-chat-hf",
+    ],
+)
+def test_chat_format_smoke_with_system(model_id: str):
+    hf_token = env_or_fail("HF_TOKEN")
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_id, token=hf_token, trust_remote_code=True
+    )
+    model_chat_template = tokenizer.chat_template
+
+    lm = guidance.models.Mock("")
+    lm.chat_template = CHAT_TEMPLATE_CACHE[model_chat_template]()
+
+    messages = [
+        {"role": "system", "content": "You are an LLM"},
+        {"role": "user", "content": "Good day to you!"},
+        {"role": "assistant", "content": "Hello!"},
+    ]
+    tokeniser_render = tokenizer.apply_chat_template(messages, tokenize=False)
+
+    with guidance.system():
+        lm += "You are an LLM"
+    with guidance.user():
+        lm += "Good day to you!"
+    with guidance.assistant():
+        lm += "Hello!"
+    # Only check substring due to BOS/EOS tokens
+    assert str(lm) in tokeniser_render
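
The tests assert str(lm) in tokeniser_render rather than equality because apply_chat_template typically wraps the conversation in model-specific special tokens (such as a BOS token) that the guidance-side render does not emit. A hedged illustration, using "<s>" as a stand-in BOS token rather than any particular model's:

# The substring check tolerates special tokens added only on the HF side.
guidance_render = "<|user|>\nGood day to you!<|end|>\n<|assistant|>\nHello!<|end|>\n"
hf_render = "<s>" + guidance_render  # "<s>" is an assumed, model-agnostic BOS stand-in
assert guidance_render in hf_render  # passes; strict equality would not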
