3
3
4
4
import onnxruntime_genai as og
5
5
import argparse
6
- import time
6
+ import os
7
7
import json
8
+ import time
8
9
9
10
def get_tools_list (input_tools ):
10
11
# input_tools format: '[{"name": "fn1", "description": "fn details", "parameters": {"p1": {"description": "details", "type": "string"}}},
@@ -134,8 +135,18 @@ def main(args):
134
135
messages = f"""[{{"role": "system", "content": "{ system_prompt } ", "tools": "{ prompt_tool_input } "}}]"""
135
136
else :
136
137
messages = f"""[{{"role": "system", "content": "{ system_prompt } "}}]"""
138
+
137
139
# Apply Chat Template
138
- tokenizer_input_system_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = False )
140
+ template_str = ""
141
+ tokenizer_input_system_prompt = None
142
+ jinja_path = os .path .join (args .model_path , "chat_template.jinja" )
143
+ if os .path .exists (jinja_path ):
144
+ with open (jinja_path , "r" , encoding = "utf-8" ) as f :
145
+ template_str = f .read ()
146
+ tokenizer_input_system_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = False , template_str = template_str )
147
+ else :
148
+ tokenizer_input_system_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = False )
149
+
139
150
input_tokens = tokenizer .encode (tokenizer_input_system_prompt )
140
151
# Ignoring the last end-of-text token, as it messes up the generation when grammar is enabled
141
152
if guidance_type :
@@ -156,8 +167,13 @@ def main(args):
156
167
if args .timings : started_timestamp = time .time ()
157
168
158
169
messages = f"""[{{"role": "user", "content": "{ text } "}}]"""
170
+
159
171
# Apply Chat Template
160
- user_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = True )
172
+ user_prompt = ""
173
+ if os .path .exists (jinja_path ):
174
+ user_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = True , template_str = template_str )
175
+ else :
176
+ user_prompt = tokenizer .apply_chat_template (messages = messages , add_generation_prompt = True )
161
177
input_tokens = tokenizer .encode (user_prompt )
162
178
generator .append_tokens (input_tokens )
163
179
0 commit comments