Skip to content

Commit da39c79

Browse files
authored
Merge branch 'ggml-org:master' into glm45
2 parents 2232baa + f738989 commit da39c79

File tree

5 files changed

+101
-32
lines changed

5 files changed

+101
-32
lines changed

common/chat.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16461646
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
16471647
);
16481648

1649-
if (auto res = builder.try_find_regex(open_regex)) {
1649+
while (auto res = builder.try_find_regex(open_regex)) {
16501650
const auto & block_start = res->groups[1];
16511651
std::string block_end = block_start.empty() ? "" : "```";
16521652

@@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16681668
builder.consume_literal(block_end);
16691669
builder.consume_spaces();
16701670
}
1671-
builder.add_content(builder.consume_rest());
16721671
} else {
16731672
throw common_chat_msg_partial_exception("failed to parse tool call");
16741673
}
@@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16931692
builder.consume_spaces();
16941693
}
16951694
}
1696-
builder.add_content(builder.consume_rest());
16971695
}
1698-
} else {
1699-
builder.add_content(builder.consume_rest());
17001696
}
1697+
1698+
builder.add_content(builder.consume_rest());
17011699
}
17021700

17031701
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,12 +2106,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
21062106
s_mmq_wg_denoms = { 32, 64, 1 };
21072107

21082108
// spec constants and tile sizes for quant matmul (Qi_K)
2109-
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
2110-
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
2111-
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
2112-
l_mmq_wg_denoms_k = { 64, 128, 1 };
2113-
m_mmq_wg_denoms_k = { 32, 64, 1 };
2114-
s_mmq_wg_denoms_k = { 32, 32, 1 };
2109+
l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
2110+
m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
2111+
s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
2112+
l_mmq_wg_denoms_k = { 128, 256, 1 };
2113+
m_mmq_wg_denoms_k = { 128, 128, 1 };
2114+
s_mmq_wg_denoms_k = { 32, 64, 1 };
21152115

21162116
// spec constants and tile sizes for quant matmul_id
21172117
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
@@ -5022,26 +5022,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
50225022
ggml_vk_queue_command_pools_cleanup(dst->device);
50235023
}
50245024

5025-
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
5025+
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
50265026
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
50275027

50285028
uint32_t split_k = 1;
5029-
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
5029+
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
50305030
// If k is 'large' and the SMs will fill less than halfway, use split_k.
50315031
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
50325032
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
5033-
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
5034-
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
5035-
// Clamp to 2 or 4
5036-
split_k = std::min(split_k, 4u);
5037-
if (split_k == 3) {
5038-
split_k = 2;
5033+
5034+
if (k >= 2048) {
5035+
if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
5036+
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
5037+
} else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
5038+
split_k = 3;
50395039
}
5040-
if (ctx->device->coopmat2) {
5041-
// coopmat2 shader expects splits to be aligned to 256
5042-
while (split_k > 1 && ((k / split_k) % 256) != 0) {
5043-
split_k /= 2;
5040+
// Cap the split at 8x. Unless k is huge this is a lot of overhead.
5041+
split_k = std::min(split_k, 8u);
5042+
5043+
// ggml_vk_matmul will align the splits to be a multiple of 256.
5044+
// If this rounded up size would cause the last split to be empty,
5045+
// then reduce the split count.
5046+
while (true) {
5047+
if (split_k == 1) {
5048+
break;
5049+
}
5050+
uint32_t k_split = CEIL_DIV(k, split_k);
5051+
k_split = ROUNDUP_POW2(k_split, 256);
5052+
if (k_split * (split_k - 1) < k) {
5053+
break;
50445054
}
5055+
split_k--;
50455056
}
50465057
}
50475058
}
@@ -5053,9 +5064,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
50535064
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
50545065

50555066
if (ctx->device->coopmat2) {
5067+
const uint32_t shader_core_count = ctx->device->shader_core_count;
5068+
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
5069+
const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
5070+
50565071
// Use large shader when the N dimension is greater than the medium shader's tile size
50575072
uint32_t crossover_large = mmp->m->wg_denoms[1];
5058-
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
5073+
5074+
// Prefer large over medium if either:
5075+
// - medium or large tiles would overfill the GPU
5076+
// - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
5077+
// (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
5078+
bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
5079+
// split_k==3 with large tiles likely better than medium tiles with no split_k.
5080+
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
5081+
5082+
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
50595083
return aligned ? mmp->a_l : mmp->l;
50605084
}
50615085
// Use medium shader when the N dimension is greater than the small shader's tile size
@@ -5099,7 +5123,11 @@ static void ggml_vk_matmul(
50995123

51005124
GGML_ASSERT(batch_stride_d == m * n);
51015125

5102-
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
5126+
// Round the split size up to a multiple of 256 (k-quant alignment)
5127+
uint32_t k_split = CEIL_DIV(k, split_k);
5128+
k_split = ROUNDUP_POW2(k_split, 256);
5129+
5130+
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
51035131
// Make sure enough workgroups get assigned for split k to work
51045132
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
51055133
ggml_vk_sync_buffers(subctx);

scripts/compare-llama-bench.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def __init__(self, tool: str = "llama-bench"):
326326

327327
# Set table name and schema based on tool
328328
if self.tool == "llama-bench":
329-
self.table_name = "test"
329+
self.table_name = "llama_bench"
330330
db_fields = LLAMA_BENCH_DB_FIELDS
331331
db_types = LLAMA_BENCH_DB_TYPES
332332
elif self.tool == "test-backend-ops":
@@ -409,17 +409,17 @@ def __init__(self, data_file: str, tool: Any):
409409

410410
# Tool selection logic
411411
if tool is None:
412-
if "test" in table_names:
413-
self.table_name = "test"
412+
if "llama_bench" in table_names:
413+
self.table_name = "llama_bench"
414414
self.tool = "llama-bench"
415415
elif "test_backend_ops" in table_names:
416416
self.table_name = "test_backend_ops"
417417
self.tool = "test-backend-ops"
418418
else:
419419
raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
420420
elif tool == "llama-bench":
421-
if "test" in table_names:
422-
self.table_name = "test"
421+
if "llama_bench" in table_names:
422+
self.table_name = "llama_bench"
423423
self.tool = "llama-bench"
424424
else:
425425
raise RuntimeError(f"Table 'llama_bench' not found for tool 'llama-bench'. Available tables: {table_names}")

tests/test-chat.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,33 @@ static void test_template_output_parsers() {
953953
/* is_partial= */ false,
954954
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
955955

956+
// Test multiple tool calls
957+
common_chat_msg message_assist_multiple_calls;
958+
message_assist_multiple_calls.role = "assistant";
959+
message_assist_multiple_calls.content = "";
960+
message_assist_multiple_calls.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
961+
message_assist_multiple_calls.tool_calls.push_back({"python", "{\"code\":\"print('hello')\"}", ""});
962+
963+
assert_msg_equals(
964+
message_assist_multiple_calls,
965+
common_chat_parse(
966+
"<tool_call>\n"
967+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
968+
"</tool_call>\n"
969+
"<tool_call>\n"
970+
"{\"name\": \"python\", \"arguments\": {\"code\":\"print('hello')\"}}\n"
971+
"</tool_call>",
972+
/* is_partial= */ false,
973+
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
974+
975+
assert_msg_equals(
976+
message_assist_multiple_calls,
977+
common_chat_parse(
978+
"<function=special_function>{\"arg1\": 1}</function>\n"
979+
"<function=python>{\"code\":\"print('hello')\"}</function>",
980+
/* is_partial= */ false,
981+
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
982+
956983
assert_msg_equals(
957984
simple_assist_msg(
958985
"This is not a tool call:",
@@ -1039,6 +1066,22 @@ static void test_template_output_parsers() {
10391066
"<tool_call>\n"
10401067
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
10411068
"</tool_call>");
1069+
1070+
// Test multiple tool calls with template
1071+
common_chat_msg message_assist_multiple_calls_template;
1072+
message_assist_multiple_calls_template.role = "assistant";
1073+
message_assist_multiple_calls_template.content = "";
1074+
message_assist_multiple_calls_template.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
1075+
message_assist_multiple_calls_template.tool_calls.push_back({"python", "{\"code\":\"print('test')\"}", ""});
1076+
1077+
test_templates(tmpls.get(), end_tokens, message_assist_multiple_calls_template, tools,
1078+
"<tool_call>\n"
1079+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
1080+
"</tool_call>\n"
1081+
"<tool_call>\n"
1082+
"{\"name\": \"python\", \"arguments\": {\"code\":\"print('test')\"}}\n"
1083+
"</tool_call>");
1084+
10421085
test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
10431086
"<tool_call>\n"
10441087
"{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"

tools/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,7 +1738,7 @@ struct sql_printer : public printer {
17381738

17391739
void print_header(const cmd_params & params) override {
17401740
std::vector<std::string> fields = test::get_fields();
1741-
fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n");
1741+
fprintf(fout, "CREATE TABLE IF NOT EXISTS llama_bench (\n");
17421742
for (size_t i = 0; i < fields.size(); i++) {
17431743
fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(),
17441744
i < fields.size() - 1 ? "," : "");
@@ -1749,7 +1749,7 @@ struct sql_printer : public printer {
17491749
}
17501750

17511751
void print_test(const test & t) override {
1752-
fprintf(fout, "INSERT INTO test (%s) ", join(test::get_fields(), ", ").c_str());
1752+
fprintf(fout, "INSERT INTO llama_bench (%s) ", join(test::get_fields(), ", ").c_str());
17531753
fprintf(fout, "VALUES (");
17541754
std::vector<std::string> values = t.get_values();
17551755
for (size_t i = 0; i < values.size(); i++) {

0 commit comments

Comments
 (0)