Skip to content

Commit 017ac4a

Browse files
committed
feat: handle the deprecated default chat template
1 parent dc1f076 commit 017ac4a

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

app/api/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ async def generate_text(
352352
chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
353353
tokenize=True,
354354
)
355-
prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens)
355+
prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens) # type: ignore
356356

357357
async def _stream() -> AsyncGenerator[bytes, None]:
358358
start = 0

app/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,14 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
756756
tokenize=False,
757757
add_generation_prompt=True,
758758
)
759+
elif hasattr(tokenizer, "default_chat_template") and tokenizer.default_chat_template:
760+
# This largely depends on how older versions of HF tokenizers behave and may not work universally
761+
tokenizer.chat_template = tokenizer.default_chat_template
762+
prompt = tokenizer.apply_chat_template(
763+
[dump_pydantic_object_to_dict(message) for message in messages],
764+
tokenize=False,
765+
add_generation_prompt=True,
766+
)
759767
else:
760768
system_content = ""
761769
prompt_parts: List[str] = []
@@ -835,4 +843,3 @@ def get_prompt_from_messages(tokenizer: PreTrainedTokenizer, messages: List[Prom
835843
"25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.',
836844
"55540447": "linkage concept"
837845
}
838-

tests/app/test_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,27 @@ def test_get_prompt_with_chat_template():
412412
assert prompt == "Mock chat template applied"
413413

414414

415+
def test_get_prompt_with_default_chat_template():
416+
with patch('transformers.PreTrainedTokenizer') as tok:
417+
mock_tokenizer = tok.return_value
418+
mock_tokenizer.chat_template = None
419+
mock_tokenizer.default_chat_template = "Mock default chat template"
420+
mock_tokenizer.apply_chat_template.return_value = "Mock default chat template applied"
421+
messages = [
422+
PromptMessage(content="Alright?", role=PromptRole.USER.value),
423+
PromptMessage(content="Yeah.", role=PromptRole.ASSISTANT.value),
424+
]
425+
426+
prompt = get_prompt_from_messages(mock_tokenizer, messages)
427+
428+
assert prompt == "Mock default chat template applied"
429+
430+
415431
def test_get_prompt_without_chat_template():
416432
with patch('transformers.PreTrainedTokenizer') as tok:
417433
mock_tokenizer = tok.return_value
418434
mock_tokenizer.chat_template = None
435+
mock_tokenizer.default_chat_template = None
419436
messages = [
420437
PromptMessage(content="You are a helpful assistant.", role=PromptRole.SYSTEM.value),
421438
PromptMessage(content="Alright?", role=PromptRole.USER.value),
@@ -432,9 +449,9 @@ def test_get_prompt_with_no_messages():
432449
with patch('transformers.PreTrainedTokenizer') as tok:
433450
mock_tokenizer = tok.return_value
434451
mock_tokenizer.chat_template = None
452+
mock_tokenizer.default_chat_template = None
435453
messages = []
436454

437455
prompt = get_prompt_from_messages(mock_tokenizer, messages)
438456

439-
expected_prompt = "\n<|assistant|>\n"
440-
assert prompt == expected_prompt
457+
assert prompt == "\n<|assistant|>\n"

0 commit comments

Comments (0)