Skip to content

Commit 5e991fb

Browse files
sayanshaw24 (Sayan Shaw)
and authored
Add tools as chat template input (#951)
* add tools as chat template input * chore: trigger CI * chore: trigger CI * chore: trigger CI * add python test --------- Co-authored-by: Sayan Shaw <[email protected]>
1 parent 666d9aa commit 5e991fb

File tree

8 files changed

+56
-17
lines changed

8 files changed

+56
-17
lines changed

include/ortx_tokenizer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,15 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok
213213
* @param tokenizer Pointer to an OrtxTokenizer used for template processing.
214214
* @param template_str Null-terminated string representing the chat template; can be null if tokenizer.json has one.
215215
* @param input Null-terminated string containing the input to be processed.
216+
* @param tools Null-terminated string containing the function tools.
216217
* @param output Pointer to an OrtxTensorResult that will be populated with the output strings,
217218
* if tokenize is true, the ids will be in the output as indexed 1.
218219
* @param add_generation_prompt Indicates whether to add a generation prompt to the output.
219220
* @param tokenize Indicates whether to tokenize the templated text to IDs
220221
* @return extError_t Returns an error code indicating success or the type of failure.
221222
*/
222223
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
223-
const char* input, OrtxTensorResult** output,
224+
const char* input, const char* tools, OrtxTensorResult** output,
224225
bool add_generation_prompt, bool tokenize);
225226

226227
#ifdef __cplusplus

onnxruntime_extensions/pp_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def tokenize(self, text, add_special_tokens = True):
5858
def detokenize(self, tokens):
5959
return batch_detokenize(self.tokenizer, [tokens])
6060

61-
def apply_chat_template(self, chat, template="", add_generation_prompt=True, tokenize=False):
61+
def apply_chat_template(self, chat, template="", tools="", add_generation_prompt=True, tokenize=False):
6262
result = _apply_chat_template(
63-
self.tokenizer, template, chat, add_generation_prompt, tokenize)
63+
self.tokenizer, template, chat, tools, add_generation_prompt, tokenize)
6464
return tensor_result_get_at(result, 1 if tokenize else 0)
6565

6666
def __del__(self):

pyop/py_c_api.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,13 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
229229

230230
m.def(
231231
"apply_chat_template",
232-
[](std::uintptr_t h, const std::string& template_str, const std::string& input, bool add_generation_prompt,
233-
bool tokenize) -> std::uintptr_t {
232+
[](std::uintptr_t h, const std::string& template_str, const std::string& input, const std::string& tools,
233+
bool add_generation_prompt, bool tokenize) -> std::uintptr_t {
234234
OrtxTokenizer* tokenizer = reinterpret_cast<OrtxTokenizer*>(h);
235235
OrtxTensorResult* result{};
236236
auto err = OrtxApplyChatTemplate(tokenizer, template_str.empty() ? nullptr : template_str.c_str(),
237-
input.c_str(), &result, add_generation_prompt, tokenize);
237+
input.c_str(), tools.empty() ? nullptr : tools.c_str(),
238+
&result, add_generation_prompt, tokenize);
238239
if (err != kOrtxOK) {
239240
throw std::runtime_error(std::string("Failed to apply chat template: ") + OrtxGetLastErrorMessage());
240241
}

shared/api/c_api_tokenizer.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, Or
355355
}
356356

357357
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
358-
const char* input, OrtxTensorResult** output, bool add_generation_prompt,
358+
const char* input, const char* tools,
359+
OrtxTensorResult** output, bool add_generation_prompt,
359360
bool tokenize) {
360361
if (tokenizer == nullptr && template_str == nullptr) {
361362
ReturnableStatus::last_error_message_ = "both tokenizer and template_str are null, no template to apply";
@@ -375,7 +376,7 @@ extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, c
375376

376377
std::string text;
377378
std::vector<extTokenId_t> ids_vec;
378-
status = token_ptr->ApplyChatTemplate(template_str, input, text, ids_vec, add_generation_prompt, tokenize);
379+
status = token_ptr->ApplyChatTemplate(template_str, input, tools, text, ids_vec, add_generation_prompt, tokenize);
379380
if (status.IsOk()) {
380381
auto result = std::make_unique<ort_extensions::TensorResult>();
381382
std::vector<std::unique_ptr<ortc::TensorBase>> tensors;

shared/api/chat_template.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,8 @@ void TokenizerImpl::InitializeChatParameters(const char* template_str,
718718
}
719719

720720
// ApplyChatTemplate method to choose the template logic based on chat_template
721-
OrtxStatus TokenizerImpl::ApplyChatTemplate(const TokenizerImpl::MessageList& message_list, std::string& output,
721+
OrtxStatus TokenizerImpl::ApplyChatTemplate(const TokenizerImpl::MessageList& message_list,
722+
const char* tools, std::string& output,
722723
bool add_generation_prompt) const {
723724
// Note: The official chat template from this model's config file may not be supported.
724725
// However, we do not throw an error until checking model_to_template_map as the user
@@ -734,6 +735,20 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(const TokenizerImpl::MessageList& me
734735

735736
messages = message_list;
736737

738+
if (tools && *tools) {
739+
tool_calls = std::string(tools);
740+
if (!messages.empty()) {
741+
if (messages[0].find("tools") != messages[0].end()) {
742+
messages[0]["tools"] = tool_calls;
743+
tools_in_user_message = true;
744+
}
745+
if (messages[0].find("tool_calls") != messages[0].end()) {
746+
messages[0]["tool_calls"] = tool_calls;
747+
tools_in_user_message = true;
748+
}
749+
}
750+
}
751+
737752
// Apply the corresponding chat template if it is supported
738753
if (chat_template == PHI4_CHAT_TEMPLATE) {
739754
return Phi4ChatTemplate(output, add_generation_prompt);
@@ -762,9 +777,9 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(const TokenizerImpl::MessageList& me
762777
return {};
763778
}
764779

765-
OrtxStatus TokenizerImpl::ApplyChatTemplate(const char* template_str, const char* message, std::string& output,
766-
std::vector<extTokenId_t>& ids_vec, bool add_generation_prompt,
767-
bool tokenize) const {
780+
OrtxStatus TokenizerImpl::ApplyChatTemplate(const char* template_str, const char* message, const char* tools,
781+
std::string& output, std::vector<extTokenId_t>& ids_vec,
782+
bool add_generation_prompt, bool tokenize) const {
768783
OrtxStatus status;
769784
std::string input_str = minja::normalize_newlines(message);
770785
auto activated_str = tok_config_->chat_template_.c_str();
@@ -783,7 +798,7 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(const char* template_str, const char
783798
return {kOrtxErrorInvalidArgument, "Invalid JSON format in chat message."};
784799
}
785800

786-
status = ApplyChatTemplate(message_list, output, add_generation_prompt);
801+
status = ApplyChatTemplate(message_list, tools, output, add_generation_prompt);
787802
} else {
788803
using json = nlohmann::ordered_json;
789804
std::string text;

shared/api/tokenizer_impl.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ class TokenizerImpl : public OrtxObjectImpl {
5858
std::string chat_template;
5959
mutable MessageList messages;
6060

61+
mutable std::string tool_calls;
62+
6163
std::string bos_token;
6264
std::string eos_token;
6365
std::vector<std::string> custom_tools;
64-
bool tools_in_user_message;
66+
mutable bool tools_in_user_message;
6567
std::string strftime_now;
6668
std::string date_string;
6769
std::vector<std::string> builtin_tools;
@@ -81,7 +83,7 @@ class TokenizerImpl : public OrtxObjectImpl {
8183
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const;
8284
OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
8385
std::vector<std::vector<extTokenId_t>>& t_ids) const;
84-
OrtxStatus ApplyChatTemplate(const char* template_str, const char* message, std::string& output,
86+
OrtxStatus ApplyChatTemplate(const char* template_str, const char* message, const char* tools, std::string& output,
8587
std::vector<extTokenId_t>& ids_vec, bool add_generation_prompt, bool tokenize) const;
8688

8789
private:
@@ -95,7 +97,7 @@ class TokenizerImpl : public OrtxObjectImpl {
9597
const std::string& date_str = "26 Jul 2024",
9698
const std::vector<std::string>& builtin_tools_param = {});
9799

98-
OrtxStatus ApplyChatTemplate(const MessageList& messages, std::string& output,
100+
OrtxStatus ApplyChatTemplate(const MessageList& messages, const char* tools, std::string& output,
99101
bool add_generation_prompt) const;
100102

101103
using bpe_tokenizer_t = std::unique_ptr<JsonFastTokenizer>;

test/pp_api_test/test_tokenizer_chat.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ TEST(OrtxTokenizerTest, Phi4ChatTemplate) {
3737

3838
auto err = OrtxApplyChatTemplate(
3939
tokenizer.get(), nullptr,
40-
messages_json.c_str(), templated_text.ToBeAssigned(), true, false);
40+
messages_json.c_str(), nullptr, templated_text.ToBeAssigned(), true, false);
4141

4242
if (err != kOrtxOK) {
4343
std::cout << "Failed to apply chat template, stopping the test." << std::endl;

test/test_pp_api.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,25 @@ def test_phi4_chat_template(self):
331331
tokenizer = pp_api.Tokenizer(model_id)
332332
ortx_inputs = tokenizer.apply_chat_template(message_json)
333333
np.testing.assert_array_equal(ortx_inputs, inputs)
334+
335+
def test_chat_tools_input(self):
336+
model_id = util.get_test_data_file("data/models/phi-4")
337+
messages = [
338+
{"role": "system", "content": "You are a medieval knight and must provide explanations to modern people."},
339+
{"role": "user", "content": "How should I explain the Internet?"},
340+
]
341+
message_json = json.dumps(messages)
342+
tokenizer = pp_api.Tokenizer(model_id)
343+
344+
# Note: we simply test passing in a tools input to apply_chat_template here,
345+
# we do not compare with HF as they place the result of the function call in their output,
346+
# and we do not have the ability to call a function in-line within the C++ chat template backend.
347+
tool_calls = """[{"name": "fn1", "description": "fn details", "parameters": {"p1": {"description": "details", "type": "string"}}}, {"fn2": 2},{"fn3": 3}]"""
348+
349+
try:
350+
tokenizer.apply_chat_template(chat=message_json, tools=tool_calls)
351+
except Exception as e:
352+
assert False, f"Error while trying to pass in tools to chat template: {e}"
334353

335354
def test_qwen2_5_vl_chat_template(self):
336355
model_id = "Qwen/Qwen2.5-VL-72B-Instruct"

0 commit comments

Comments
 (0)