Skip to content

Commit 00e5654

Browse files
Add parameters to chat template in chat example (microsoft#1673)
### Description

This PR updates `apply_chat_template` to read a `chat_template.jinja` file.

### Motivation and Context

This is used for [OpenAI's gpt-oss models](https://openai.com/index/introducing-gpt-oss/).
1 parent 0a721da commit 00e5654

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

cmake/deps.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
1414
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
1515
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
1616
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
17-
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;492286d9b7b1e674a2d4ce81bd22a7668c3b58fa
17+
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;af289f57acae13a0ee1926605e0a7cf53efd8a0c
1818

1919
# These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
2020
llguidance;https://github.com/microsoft/llguidance.git;2d2f1de3c87e3289528affc346f734f7471216d9

examples/python/model-chat.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import onnxruntime_genai as og
55
import argparse
6-
import time
6+
import os
77
import json
8+
import time
89

910
def get_tools_list(input_tools):
1011
# input_tools format: '[{"name": "fn1", "description": "fn details", "parameters": {"p1": {"description": "details", "type": "string"}}},
@@ -134,8 +135,18 @@ def main(args):
134135
messages = f"""[{{"role": "system", "content": "{system_prompt}", "tools": "{prompt_tool_input}"}}]"""
135136
else:
136137
messages = f"""[{{"role": "system", "content": "{system_prompt}"}}]"""
138+
137139
# Apply Chat Template
138-
tokenizer_input_system_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=False)
140+
template_str = ""
141+
tokenizer_input_system_prompt = None
142+
jinja_path = os.path.join(args.model_path, "chat_template.jinja")
143+
if os.path.exists(jinja_path):
144+
with open(jinja_path, "r", encoding="utf-8") as f:
145+
template_str = f.read()
146+
tokenizer_input_system_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=False, template_str=template_str)
147+
else:
148+
tokenizer_input_system_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=False)
149+
139150
input_tokens = tokenizer.encode(tokenizer_input_system_prompt)
140151
# Ignoring the last end of text token as it messes up the generation when grammar is enabled
141152
if guidance_type:
@@ -156,8 +167,13 @@ def main(args):
156167
if args.timings: started_timestamp = time.time()
157168

158169
messages = f"""[{{"role": "user", "content": "{text}"}}]"""
170+
159171
# Apply Chat Template
160-
user_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True)
172+
user_prompt = ""
173+
if os.path.exists(jinja_path):
174+
user_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True, template_str=template_str)
175+
else:
176+
user_prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True)
161177
input_tokens = tokenizer.encode(user_prompt)
162178
generator.append_tokens(input_tokens)
163179

0 commit comments

Comments
 (0)