Skip to content

Commit 90789cd

Browse files
author
ochafik
committed
Inject date_string in llama 3.x + test it & functionary v2
ggml-org#12729
1 parent a604b2d commit 90789cd

File tree

3 files changed

+128
-63
lines changed

3 files changed

+128
-63
lines changed

common/chat.cpp

Lines changed: 78 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,22 @@
1010

1111
#include <cstdio>
1212
#include <exception>
13+
#include <iostream>
1314
#include <optional>
1415
#include <stdexcept>
1516
#include <string>
1617
#include <vector>
1718

1819

20+
// Renders a time point as a string in the *local* timezone using strftime-style
// specifiers (e.g. "%d %b %Y" -> "05 Apr 2025").
// Returns an empty string when the time cannot be converted to a calendar time.
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
    auto time = std::chrono::system_clock::to_time_t(now);
    // std::localtime may return nullptr for times it cannot represent; the
    // original code dereferenced the result unconditionally.
    // NOTE(review): std::localtime also writes to a shared static buffer and is
    // not thread-safe — consider localtime_r / localtime_s where available.
    const std::tm * tm_ptr = std::localtime(&time);
    if (tm_ptr == nullptr) {
        return "";
    }
    std::tm local_time = *tm_ptr;  // copy out of the static buffer immediately
    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());
    return ss.str();
}
28+
1929
static std::string string_diff(const std::string & last, const std::string & current) {
2030
if (last.empty()) {
2131
return current;
@@ -123,6 +133,7 @@ struct templates_params {
123133
bool stream;
124134
std::string grammar;
125135
bool add_generation_prompt = true;
136+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
126137
};
127138

128139
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -1017,72 +1028,75 @@ static void expect_tool_parameters(const std::string & name, const json & parame
10171028
}
10181029
}
10191030

1020-
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
1031+
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
10211032
auto builtin_tools = json::array();
10221033
common_chat_params data;
1023-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1024-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1025-
std::vector<std::string> tool_rules;
1034+
if (!inputs.tools.is_null()) {
1035+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1036+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1037+
std::vector<std::string> tool_rules;
10261038

1027-
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
1028-
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
1029-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
1030-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
1031-
expect_tool_parameters(name, parameters, {"query"});
1032-
} else if (name == "python" || name == "code_interpreter") {
1033-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
1034-
expect_tool_parameters(name, parameters, {"code"});
1035-
} else {
1036-
return false;
1037-
}
1039+
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
1040+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
1041+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
1042+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
1043+
expect_tool_parameters(name, parameters, {"query"});
1044+
} else if (name == "python" || name == "code_interpreter") {
1045+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
1046+
expect_tool_parameters(name, parameters, {"code"});
1047+
} else {
1048+
return false;
1049+
}
10381050

1039-
std::vector<std::string> kvs;
1040-
for (const auto & [key, value] : parameters.at("properties").items()) {
1041-
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
1042-
}
1051+
std::vector<std::string> kvs;
1052+
for (const auto & [key, value] : parameters.at("properties").items()) {
1053+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
1054+
}
10431055

1044-
tool_rules.push_back(
1045-
builder.add_rule(
1046-
name + "-call",
1047-
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
1048-
builtin_tools.push_back(name);
1056+
tool_rules.push_back(
1057+
builder.add_rule(
1058+
name + "-call",
1059+
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
1060+
builtin_tools.push_back(name);
10491061

1050-
return true;
1051-
};
1062+
return true;
1063+
};
10521064

1053-
foreach_function(inputs.tools, [&](const json & tool) {
1054-
const auto & function = tool.at("function");
1055-
std::string name = function.at("name");
1056-
auto parameters = function.at("parameters");
1057-
builder.resolve_refs(parameters);
1065+
foreach_function(inputs.tools, [&](const json & tool) {
1066+
const auto & function = tool.at("function");
1067+
std::string name = function.at("name");
1068+
auto parameters = function.at("parameters");
1069+
builder.resolve_refs(parameters);
10581070

1059-
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
1060-
if (allow_python_tag_builtin_tools) {
1061-
handle_builtin_tool(name, parameters);
1071+
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
1072+
if (allow_python_tag_builtin_tools) {
1073+
handle_builtin_tool(name, parameters);
1074+
}
1075+
tool_rules.push_back(
1076+
builder.add_rule(
1077+
name + "-call",
1078+
"\"{\" space "
1079+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
1080+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
1081+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
1082+
"\"}\" space"));
1083+
});
1084+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
1085+
data.grammar_triggers.push_back({
1086+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1087+
"(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*",
1088+
});
1089+
if (!builtin_tools.empty()) {
1090+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1091+
data.preserved_tokens.push_back("<|python_tag|>");
10621092
}
1063-
tool_rules.push_back(
1064-
builder.add_rule(
1065-
name + "-call",
1066-
"\"{\" space "
1067-
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
1068-
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
1069-
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
1070-
"\"}\" space"));
1071-
});
1072-
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
1073-
data.grammar_triggers.push_back({
1074-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1075-
"(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*",
1093+
// Allow a few empty lines on top of the usual constrained json schema space rule.
1094+
builder.add_rule("root", string_join(tool_rules, " | "));
10761095
});
1077-
if (!builtin_tools.empty()) {
1078-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1079-
data.preserved_tokens.push_back("<|python_tag|>");
1080-
}
1081-
// Allow a few empty lines on top of the usual constrained json schema space rule.
1082-
builder.add_rule("root", string_join(tool_rules, " | "));
1083-
});
1084-
data.additional_stops.push_back("<|eom_id|>");
1096+
data.additional_stops.push_back("<|eom_id|>");
1097+
}
10851098
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
1099+
{"date_string", format_time(inputs.now, "%d %b %Y")},
10861100
{"tools_in_user_message", false},
10871101
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
10881102
});
@@ -1234,7 +1248,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
12341248
LOG_DBG("%s\n", __func__);
12351249
common_chat_params data;
12361250
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
1237-
{"datetime", "Jan 29 2025 13:00:00 GMT"},
1251+
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
12381252
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
12391253
});
12401254
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1648,6 +1662,7 @@ static common_chat_params common_chat_templates_apply_jinja(
16481662
params.add_generation_prompt = inputs.add_generation_prompt;
16491663
params.tool_choice = inputs.tool_choice;
16501664
params.grammar = inputs.grammar;
1665+
params.now = inputs.now;
16511666
if (!inputs.json_schema.empty()) {
16521667
params.json_schema = json::parse(inputs.json_schema);
16531668
}
@@ -1678,7 +1693,7 @@ static common_chat_params common_chat_templates_apply_jinja(
16781693
return common_chat_params_init_command_r7b(tmpl, params);
16791694
}
16801695

1681-
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
1696+
// Hermes 2/3 Pro, Qwen 2.5 Instruct
16821697
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
16831698
return common_chat_params_init_hermes_2_pro(tmpl, params);
16841699
}
@@ -1699,6 +1714,12 @@ static common_chat_params common_chat_templates_apply_jinja(
16991714
return common_chat_params_init_firefunction_v2(tmpl, params);
17001715
}
17011716

1717+
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
1718+
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
1719+
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
1720+
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
1721+
}
1722+
17021723
// Plain handler (no tools)
17031724
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
17041725
return common_chat_params_init_without_tools(tmpl, params);
@@ -1710,12 +1731,6 @@ static common_chat_params common_chat_templates_apply_jinja(
17101731
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
17111732
}
17121733

1713-
// Llama 3.1, 3.2, 3.3 (w/ tools)
1714-
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
1715-
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
1716-
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
1717-
}
1718-
17191734
// Mistral Nemo (w/ tools)
17201735
if (src.find("[TOOL_CALLS]") != std::string::npos) {
17211736
return common_chat_params_init_mistral_nemo(tmpl, params);

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ struct common_chat_templates_inputs {
122122
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
123123
bool parallel_tool_calls = false;
124124
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
125+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
125126
};
126127

127128
struct common_chat_params {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python
import pytest

# ensure grandparent path is in sys.path
from pathlib import Path
import sys

path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

# These imports depend on the path inserted above, so they must come after it.
# (The original imported unit.test_tool_call *before* the sys.path.insert,
# which only worked when the test was launched from the right cwd.)
from unit.test_tool_call import TEST_TOOL

import datetime
from utils import *

# Module-level handle to the server under test; (re)assigned by the
# autouse fixture below before each test runs.
server: ServerProcess

# Model download + startup can be slow on CI, hence the generous timeout.
TIMEOUT_SERVER_START = 15*60
18+
19+
@pytest.fixture(autouse=True)
def create_server():
    """Provision a fresh tinyllama-2 test server before every test in this module."""
    global server
    srv = ServerPreset.tinyllama2()
    srv.model_alias = "tinyllama-2"
    srv.n_slots = 1
    srv.server_port = 8081
    server = srv
26+
27+
28+
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,date_format", [
    ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, date_format: str, tools: list[dict]):
    """Check that applying the chat template injects today's date into the prompt.

    Each template family renders the date with its own strftime format
    (``date_format``); the test asserts that exact rendering appears in the
    prompt returned by ``/apply-template`` regardless of whether tools are
    absent, empty, or present.
    """
    # Renamed the parameter from `format` to `date_format`: the original name
    # shadowed the `format` builtin.
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "tools": tools,
    })
    assert res.status_code == 200
    prompt = res.body["prompt"]

    # NOTE(review): compares client-side "today" to the server-rendered date;
    # could flake if the test straddles midnight — acceptable for CI.
    today_str = datetime.date.today().strftime(date_format)
    assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"

0 commit comments

Comments
 (0)