Skip to content

Commit 59d8d4e

Browse files
authored
vulkan: improve topk perf for large k, fix overflow in unit tests (#17582)
1 parent d82b7a7 commit 59d8d4e

File tree

2 files changed: 8 additions (+8) and 4 deletions (-4).

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10239,7 +10239,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
1023910239

1024010240
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
1024110241
// But if K is larger, then we need a larger workgroup
10242-
uint32_t max_pipeline = num_topk_pipelines - 3;
10242+
uint32_t max_pipeline = num_topk_pipelines - 1;
10243+
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
10244+
max_pipeline = std::min(preferred_pipeline, max_pipeline);
1024310245
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
1024410246
// require full subgroup
1024510247
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);

tests/test-backend-ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1446,14 +1446,14 @@ struct test_case {
14461446
const uint64_t target_flops_cpu = 8ULL * GFLOP;
14471447
const uint64_t target_flops_gpu = 100ULL * GFLOP;
14481448
uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
1449-
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
1449+
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
14501450
} else {
14511451
// based on memory size
14521452
const size_t GB = 1ULL << 30;
14531453
const size_t target_size_cpu = 8 * GB;
14541454
const size_t target_size_gpu = 32 * GB;
14551455
size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
1456-
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
1456+
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
14571457
}
14581458

14591459
// duplicate the op
@@ -8043,7 +8043,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
80438043
}
80448044

80458045
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
8046-
for (auto k : {1, 10, 40}) {
8046+
8047+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
8048+
for (auto k : {1, 10, 40, 400}) {
80478049
for (auto nrows : {1, 16}) {
80488050
for (auto cols : {k, 1000, 65000, 200000}) {
80498051
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));

0 commit comments

Comments
 (0)