
Commit 072307c

Merge branch 'n_b6110' into crokeso
2 parents 2034127 + 9aba5d1 commit 072307c

15 files changed: +600, -64 lines

common/chat-parser.cpp

Lines changed: 9 additions & 1 deletion
@@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
 bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
     std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
     std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
-    std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
+    std::string arguments = "";
+    if (tool_call.contains("arguments")) {
+        if (tool_call.at("arguments").is_object()) {
+            arguments = tool_call.at("arguments").dump();
+        } else {
+            arguments = tool_call.at("arguments");
+        }
+    }
+
     return add_tool_call(name, id, arguments);
 }
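For reference, a minimal standalone sketch (not part of the commit) of what the new branch does with nlohmann::json: an object-valued "arguments" field is serialized with dump(), while a plain string is passed through unchanged, so both forms reach the downstream add_tool_call overload as a JSON string. The tool name and payloads below are made up.

// sketch only: mirrors the logic above with made-up tool-call payloads
#include <nlohmann/json.hpp>
#include <iostream>
using json = nlohmann::json;

int main() {
    json object_args = {{"name", "get_weather"}, {"arguments", {{"city", "Berlin"}}}};
    json string_args = {{"name", "get_weather"}, {"arguments", "{\"city\":\"Berlin\"}"}};

    for (const auto & tool_call : {object_args, string_args}) {
        std::string arguments;
        if (tool_call.contains("arguments")) {
            arguments = tool_call.at("arguments").is_object()
                ? tool_call.at("arguments").dump()                  // serialize object form
                : tool_call.at("arguments").get<std::string>();     // keep string form as-is
        }
        std::cout << arguments << "\n";  // both iterations print {"city":"Berlin"}
    }
}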

common/chat.cpp

Lines changed: 129 additions & 0 deletions
@@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -616,6 +617,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
         case COMMON_REASONING_FORMAT_NONE: return "none";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
         case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
+        case COMMON_REASONING_FORMAT_GRANITE: return "granite";
         default:
             throw std::runtime_error("Unknown reasoning format");
     }
@@ -1712,6 +1714,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // Pass thinking context for Granite template
+    json additional_context = {
+        {"thinking", inputs.enable_thinking},
+    };
+
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
+    data.format = COMMON_CHAT_FORMAT_GRANITE;
+
+    if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "</think>";
+        } else {
+            data.thinking_forced_open = true;
+        }
+    }
+
+    if (!inputs.tools.is_null()) {
+        // Granite uses <|tool_call|> followed by a JSON list
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
+                tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
+                    "-args", {
+                        {"type", "object"},
+                        {"properties", {
+                            {"name", {{"const", name}}},
+                            {"arguments", parameters},
+                        }},
+                        {"required", json::array({"name", "arguments"})},
+                    })));
+            });
+
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
+            auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
+
+            if (data.thinking_forced_open) {
+                builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
+            } else {
+                builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
+            }
+
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                "<|tool_call|>"
+            });
+
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<response>",
+                "</response>",
+                "<|tool_call|>",
+            };
+        });
+    } else {
+        // Handle thinking tags for non-tool responses
+        if (data.thinking_forced_open && inputs.enable_thinking) {
+            data.grammar_lazy = false;
+            data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+                builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
+            });
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<response>",
+                "</response>",
+            };
+        }
+    }
+
+    return data;
+}
+
+static void common_chat_parse_granite(common_chat_msg_parser & builder) {
+    // Parse thinking tags
+    builder.try_parse_reasoning("<think>", "</think>");
+
+    // Parse response tags using regex
+    static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
+    if (auto res = builder.try_find_regex(response_regex)) {
+        // Extract the content between the tags (capture group 1)
+        auto content = builder.str(res->groups[1]);
+        builder.add_content(content);
+        builder.move_to(res->groups[0].end);
+    }
+
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // Look for tool calls
+    static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
+    if (auto res = builder.try_find_regex(tool_call_regex)) {
+        builder.move_to(res->groups[0].end);
+
+        // Expect a JSON array of tool calls
+        auto tool_calls_data = builder.consume_json();
+        if (tool_calls_data.json.is_array()) {
+            if (!builder.add_tool_calls(tool_calls_data.json)) {
+                builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
+            }
+        } else {
+            builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
+        }
+    } else {
+        builder.add_content(builder.consume_rest());
+    }
+}
+
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -1783,6 +1903,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_command_r7b(tmpl, params);
     }
 
+    // Granite (IBM) - detects thinking / tools support
+    if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
+        return common_chat_params_init_granite(tmpl, params);
+    }
+
     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
         return common_chat_params_init_hermes_2_pro(tmpl, params);
@@ -1838,6 +1963,7 @@ static common_chat_params common_chat_templates_apply_legacy(
     int alloc_size = 0;
     std::vector<llama_chat_message> chat;
     std::vector<std::string> contents;
+
    for (const auto & msg : inputs.messages) {
        auto content = msg.content;
        for (const auto & part : msg.content_parts) {
@@ -1939,6 +2065,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
        case COMMON_CHAT_FORMAT_COMMAND_R7B:
            common_chat_parse_command_r7b(builder);
            break;
+        case COMMON_CHAT_FORMAT_GRANITE:
+            common_chat_parse_granite(builder);
+            break;
        default:
            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
    }
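For orientation, here is a hypothetical Granite-style completion that the common_chat_parse_granite function above is meant to handle; the tag names come from the template handling in this commit, but the tool name and argument values are invented for the example:

<think>The user asked about the weather, so a tool call is needed.</think>
<response>Let me look that up.</response>
<|tool_call|>[{"name": "get_weather", "arguments": {"city": "Berlin"}}]

try_parse_reasoning() moves the <think>...</think> text into reasoning_content, the <response> regex captures the visible reply as content, and the JSON array following <|tool_call|> is forwarded to add_tool_calls(); if that array fails to validate, the raw text is appended to the content instead.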

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_GRANITE,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
     COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
+    COMMON_REASONING_FORMAT_GRANITE,         // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
 struct common_params {

ggml/src/ggml-backend.cpp

Lines changed: 7 additions & 2 deletions
@@ -1077,6 +1077,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
         }
+        // if the node is still unassigned, assign it to the first backend that supports it
+        for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {
+            ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);
+        }
+        GGML_ASSERT(*cur_backend_id != -1);
     }
 
     // pass 5: split graph, find tensors that need to be copied
@@ -1104,7 +1109,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
         const int node_backend_id = tensor_backend_id(node);
 
-        assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
+        GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
 
         // check if we should start a new split based on the sources of the current node
         bool need_new_split = false;
@@ -1162,7 +1167,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
             size_t src_id = hash_id(src);
             const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
-            assert(src_backend_id != -1); // all inputs should be assigned by now
+            GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now
 
             if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
                 if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
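The assert to GGML_ASSERT change here is not cosmetic: plain assert compiles away in release builds built with NDEBUG, so an unassigned node could slip past the check and fail later in a harder-to-diagnose way, while GGML_ASSERT (which in ggml.h reports file and line and calls ggml's abort path) stays active. A rough illustration of the NDEBUG behavior, not the actual ggml macro:

// illustration only: 'assert' is compiled out under NDEBUG, a hand-rolled check is not
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#define ALWAYS_ASSERT(x) \
    do { if (!(x)) { fprintf(stderr, "assert failed: %s (%s:%d)\n", #x, __FILE__, __LINE__); abort(); } } while (0)

int main(void) {
    int backend_id = -1;
    assert(backend_id != -1);        // no-op when NDEBUG is defined
    ALWAYS_ASSERT(backend_id != -1); // always aborts with a message
    return 0;
}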

ggml/src/ggml-cpu/ggml-cpu-traits.cpp

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
 }  // namespace ggml::cpu
 
 bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
         if (extra && extra->context) {
             auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
             auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct
 }
 
 bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
         if (extra && extra->context) {
             auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
             auto tensor_traits = buf_extra->get_tensor_traits(op);

ggml/src/ggml-cpu/ggml-cpu-traits.h

Lines changed: 1 addition & 1 deletion
@@ -33,6 +33,6 @@ class extra_buffer_type {
 }  // namespace ggml::cpu
 
 // implemented in ggml-cpu.cpp.
-std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();
 
 #endif

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 17 additions & 20 deletions
@@ -40,7 +40,7 @@
 
 // ggml-backend interface
 
-std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
     static std::vector<ggml_backend_buffer_type_t> bufts = []() {
         std::vector<ggml_backend_buffer_type_t> bufts;
 
@@ -62,23 +62,27 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
         }
 #endif
 
-        bufts.push_back(NULL);
-
         return bufts;
     }();
 
     return bufts;
 }
 
 static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
-    return ggml_backend_cpu_get_extra_buffers_type().data();
+    static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
+        std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
+        bufts.push_back(nullptr);
+        return bufts;
+    }();
+
+    return extra_bufts.data();
 
     GGML_UNUSED(device);
 }
 
 static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
-    for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra && extra == buft) {
+    for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
+        if (extra == buft) {
             return true;
         }
     }
@@ -402,20 +406,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         return true;
     }
 
-    // extra_buffer_op?
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra) {
-            auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
-            if (buf_extra && buf_extra->supports_op(dev, op)) {
-                return true;
-            }
-        }
-    }
-
-    // the other case need host buffer.
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
-            return false;
+    // check extra buffer types
+    // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
+    for (int i = 0; i < 4; i++) {
+        if (op->src[i] && op->src[i]->buffer &&
+            ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
+            auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
+            return buf_extra->supports_op(dev, op);
         }
     }
 

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// add_id
+//------------------------------------------------------------------------------
+kernel void kernel_add_id(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * src2,
+        ulong offset2,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb02,
+        ulong nb11,
+        ulong nb21,
+        int ne0,
+        int ne1
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    src2 = (global char*)((global char*)src2 + offset2);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    int i1 = get_group_id(0);
+    int i2 = get_group_id(1);
+
+    const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    global float * dst_row  = (global float *)((global char *)dst  + i1*nb1  + i2*nb2);
+    global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
+    global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
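Judging from the indexing above, the kernel adds one row of src1, selected per output row by an int index read from src2, onto the corresponding row of src0, with one work-group per output row and work-items striding over ne0. A CPU reference of that behavior, written as a sketch with made-up sizes rather than the actual ggml host code:

// sketch: dst[i2][i1][:] = src0[i2][i1][:] + src1[ ids[i2][i1] ][:]
#include <cstdio>
#include <vector>

int main() {
    const int ne0 = 4;        // elements per row
    const int ne1 = 2;        // rows per batch (i1)
    const int ne2 = 1;        // batches (i2)
    const int src1_rows = 3;  // rows available in src1

    std::vector<float> src0(ne2 * ne1 * ne0, 1.0f);
    std::vector<float> src1(src1_rows * ne0, 2.0f);
    std::vector<int>   ids(ne2 * ne1, 1);        // each output row picks row 1 of src1
    std::vector<float> dst(ne2 * ne1 * ne0);

    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            const int i11 = ids[i2 * ne1 + i1];  // mirrors the read from src2 in the kernel
            for (int i0 = 0; i0 < ne0; ++i0) {
                dst[(i2 * ne1 + i1) * ne0 + i0] =
                    src0[(i2 * ne1 + i1) * ne0 + i0] + src1[i11 * ne0 + i0];
            }
        }
    }

    std::printf("dst[0] = %.1f\n", dst[0]);  // prints 3.0
    return 0;
}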
