Skip to content

Commit 71435cf

Browse files
author
ochafik
committed
Inject date_string in llama 3.x + fix for functionary v2
#12729
1 parent 7a84777 commit 71435cf

File tree

3 files changed

+127
-63
lines changed

3 files changed

+127
-63
lines changed

common/chat.cpp

Lines changed: 77 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
#include <optional>
88

9+
// Formats the time point `now` as text using a strftime-style `format`
// (e.g. "%d %b %Y"), interpreted in the process's local time zone.
//
// Returns the formatted string, or an empty string if `now` cannot be
// converted to a broken-down calendar time.
//
// NOTE(review): the output is local time, so a literal zone suffix inside
// `format` (such as "GMT") is only accurate when the local zone is UTC.
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
    const auto time = std::chrono::system_clock::to_time_t(now);
    std::tm local_time{};
    // std::localtime() hands back a pointer to shared static storage (unsafe when
    // several requests are formatted concurrently) and may return nullptr on failure.
    // Use the reentrant variants, which fill a caller-owned buffer and report errors.
#if defined(_WIN32)
    if (localtime_s(&local_time, &time) != 0) {
        return std::string();
    }
#else
    if (localtime_r(&time, &local_time) == nullptr) {
        return std::string();
    }
#endif
    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());
    return ss.str();
}
17+
918
// Short alias for the minja chat template type used by the functions in this file.
using common_chat_template = minja::chat_template;
1019

1120
struct common_chat_templates {
@@ -24,6 +33,7 @@ struct templates_params {
2433
std::string grammar;
2534
bool add_generation_prompt = true;
2635
bool extract_reasoning = true;
36+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
2737
};
2838

2939
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -937,72 +947,75 @@ static void expect_tool_parameters(const std::string & name, const json & parame
937947
}
938948
}
939949

940-
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
950+
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
941951
auto builtin_tools = json::array();
942952
common_chat_params data;
943-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
944-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
945-
std::vector<std::string> tool_rules;
953+
if (!inputs.tools.is_null()) {
954+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
955+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
956+
std::vector<std::string> tool_rules;
946957

947-
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
948-
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
949-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
950-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
951-
expect_tool_parameters(name, parameters, {"query"});
952-
} else if (name == "python" || name == "code_interpreter") {
953-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
954-
expect_tool_parameters(name, parameters, {"code"});
955-
} else {
956-
return false;
957-
}
958+
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
959+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
960+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
961+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
962+
expect_tool_parameters(name, parameters, {"query"});
963+
} else if (name == "python" || name == "code_interpreter") {
964+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
965+
expect_tool_parameters(name, parameters, {"code"});
966+
} else {
967+
return false;
968+
}
958969

959-
std::vector<std::string> kvs;
960-
for (const auto & [key, value] : parameters.at("properties").items()) {
961-
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
962-
}
970+
std::vector<std::string> kvs;
971+
for (const auto & [key, value] : parameters.at("properties").items()) {
972+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
973+
}
963974

964-
tool_rules.push_back(
965-
builder.add_rule(
966-
name + "-call",
967-
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
968-
builtin_tools.push_back(name);
975+
tool_rules.push_back(
976+
builder.add_rule(
977+
name + "-call",
978+
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
979+
builtin_tools.push_back(name);
969980

970-
return true;
971-
};
981+
return true;
982+
};
972983

973-
foreach_function(inputs.tools, [&](const json & tool) {
974-
const auto & function = tool.at("function");
975-
std::string name = function.at("name");
976-
auto parameters = function.at("parameters");
977-
builder.resolve_refs(parameters);
984+
foreach_function(inputs.tools, [&](const json & tool) {
985+
const auto & function = tool.at("function");
986+
std::string name = function.at("name");
987+
auto parameters = function.at("parameters");
988+
builder.resolve_refs(parameters);
978989

979-
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
980-
if (allow_python_tag_builtin_tools) {
981-
handle_builtin_tool(name, parameters);
990+
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
991+
if (allow_python_tag_builtin_tools) {
992+
handle_builtin_tool(name, parameters);
993+
}
994+
tool_rules.push_back(
995+
builder.add_rule(
996+
name + "-call",
997+
"\"{\" space "
998+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
999+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
1000+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
1001+
"\"}\" space"));
1002+
});
1003+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
1004+
data.grammar_triggers.push_back({
1005+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
1006+
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
1007+
});
1008+
if (!builtin_tools.empty()) {
1009+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1010+
data.preserved_tokens.push_back("<|python_tag|>");
9821011
}
983-
tool_rules.push_back(
984-
builder.add_rule(
985-
name + "-call",
986-
"\"{\" space "
987-
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
988-
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
989-
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
990-
"\"}\" space"));
991-
});
992-
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
993-
data.grammar_triggers.push_back({
994-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
995-
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
1012+
// Allow a few empty lines on top of the usual constrained json schema space rule.
1013+
builder.add_rule("root", string_join(tool_rules, " | "));
1014+
data.additional_stops.push_back("<|eom_id|>");
9961015
});
997-
if (!builtin_tools.empty()) {
998-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
999-
data.preserved_tokens.push_back("<|python_tag|>");
1000-
}
1001-
// Allow a few empty lines on top of the usual constrained json schema space rule.
1002-
builder.add_rule("root", string_join(tool_rules, " | "));
1003-
});
1004-
data.additional_stops.push_back("<|eom_id|>");
1016+
}
10051017
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
1018+
{"date_string", format_time(inputs.now, "%d %b %Y")},
10061019
{"tools_in_user_message", false},
10071020
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
10081021
});
@@ -1148,7 +1161,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
11481161
LOG_DBG("%s\n", __func__);
11491162
common_chat_params data;
11501163
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
1151-
{"datetime", "Jan 29 2025 13:00:00 GMT"},
1164+
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
11521165
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
11531166
});
11541167
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1591,6 +1604,7 @@ static common_chat_params common_chat_templates_apply_jinja(
15911604
params.extract_reasoning = inputs.extract_reasoning;
15921605
params.tool_choice = inputs.tool_choice;
15931606
params.grammar = inputs.grammar;
1607+
params.now = inputs.now;
15941608
if (!inputs.json_schema.empty()) {
15951609
params.json_schema = json::parse(inputs.json_schema);
15961610
}
@@ -1621,7 +1635,7 @@ static common_chat_params common_chat_templates_apply_jinja(
16211635
return common_chat_params_init_command_r7b(tmpl, params);
16221636
}
16231637

1624-
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
1638+
// Hermes 2/3 Pro, Qwen 2.5 Instruct
16251639
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
16261640
return common_chat_params_init_hermes_2_pro(tmpl, params);
16271641
}
@@ -1642,6 +1656,12 @@ static common_chat_params common_chat_templates_apply_jinja(
16421656
return common_chat_params_init_firefunction_v2(tmpl, params);
16431657
}
16441658

1659+
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
1660+
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
1661+
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
1662+
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
1663+
}
1664+
16451665
// Plain handler (no tools)
16461666
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
16471667
return common_chat_params_init_without_tools(tmpl, params);
@@ -1653,12 +1673,6 @@ static common_chat_params common_chat_templates_apply_jinja(
16531673
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
16541674
}
16551675

1656-
// Llama 3.1, 3.2, 3.3 (w/ tools)
1657-
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
1658-
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
1659-
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
1660-
}
1661-
16621676
// Mistral Nemo (w/ tools)
16631677
if (src.find("[TOOL_CALLS]") != std::string::npos) {
16641678
return common_chat_params_init_mistral_nemo(tmpl, params);

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ struct common_chat_templates_inputs {
7171
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
7272
bool parallel_tool_calls = false;
7373
bool extract_reasoning = true;
74+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
7475
};
7576

7677
struct common_chat_params {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python
2+
import pytest
3+
4+
# ensure grandparent path is in sys.path
5+
from pathlib import Path
6+
import sys
7+
8+
from unit.test_tool_call import TEST_TOOL
9+
path = Path(__file__).resolve().parents[1]
10+
sys.path.insert(0, str(path))
11+
12+
import datetime
13+
from utils import *
14+
15+
# Server under test; assigned by the autouse `create_server` fixture.
server: ServerProcess

# Value passed as `timeout_seconds` when starting the server (15 minutes).
TIMEOUT_SERVER_START = 15 * 60
18+
19+
@pytest.fixture(autouse=True)
def create_server():
    """Build a tinyllama2 preset server for each test and store it in the module-level `server`.

    The server is only configured here (alias, port, slot count); it is not started.
    """
    global server
    preset = ServerPreset.tinyllama2()
    preset.model_alias = "tinyllama-2"
    preset.server_port = 8081
    preset.n_slots = 1
    server = preset
26+
27+
28+
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,format", [
    ("meta-llama-Llama-3.3-70B-Instruct",    "%d %b %Y"),
    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
    """Check that the rendered chat prompt contains today's date.

    Starts the server with the given Jinja chat template, renders a one-message
    conversation through `/apply-template` (with no tools, an empty tool list, or
    one tool), and asserts that the current date, formatted with `format`, appears
    in the returned prompt.
    """
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    # Sample the local date on both sides of the request: if the day rolls over
    # while the request is in flight, the prompt may carry either date, and
    # checking only one of them would make the test fail spuriously at midnight.
    date_before = datetime.date.today()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "tools": tools,
    })
    date_after = datetime.date.today()

    assert res.status_code == 200
    prompt = res.body["prompt"]

    expected = [date_before.strftime(format)]
    if date_after != date_before:
        expected.append(date_after.strftime(format))
    today_str = " or ".join(expected)
    assert any(candidate in prompt for candidate in expected), \
        f"Expected today's date ({today_str}) in content ({prompt})"

0 commit comments

Comments
 (0)