
Commit 879d673

vulkan: Implement top-k (#17418)
* vulkan: Implement top-k

  Each pass launches workgroups that each sort 2^N elements (where N is
  usually 7-10) and discards all but the top K. Repeat until only K are
  left. There's also a fast path when K == 1 that just finds the max value
  rather than sorting.

* fix pipeline selection
* vulkan: Add N-ary search algorithm for topk
* microoptimizations
1 parent 6ab4e50 commit 879d673
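Below is a minimal CPU sketch of the multi-pass scheme the message describes; the function name, the default block size, and the use of std::sort are illustrative stand-ins (the shader sorts each block with a bitonic network in shared memory and carries source indices alongside values, and the N-ary search variant avoids a full sort):

    // Sketch only: keep the top K of each block, then repeat on the survivors.
    // Assumes block > k, mirroring the backend's choice of a workgroup larger
    // than K (min_pipeline = log2(k) + 1 in ggml_vk_topk).
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <vector>

    std::vector<float> topk_multipass(std::vector<float> vals, size_t k, size_t block = 128) {
        assert(block > k);
        while (vals.size() > k) {
            std::vector<float> next;
            for (size_t base = 0; base < vals.size(); base += block) {
                const size_t end = std::min(base + block, vals.size());
                // Sort this block descending and keep at most K elements of it.
                std::sort(vals.begin() + base, vals.begin() + end, std::greater<float>());
                next.insert(next.end(), vals.begin() + base,
                            vals.begin() + base + std::min(k, end - base));
            }
            vals = std::move(next);  // each pass shrinks the candidate set
        }
        return vals;  // the K largest values (indices omitted for brevity)
    }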

5 files changed: +480 −1 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 150 additions & 0 deletions
@@ -409,6 +409,7 @@ enum shader_reduction_mode {
 // argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t num_topk_moe_pipelines = 10;
+static constexpr uint32_t num_topk_pipelines = 11;
 
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@@ -515,6 +516,7 @@ struct vk_device_struct {
     bool single_queue;
     bool support_async;
     uint32_t subgroup_size;
+    uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
     bool uma;
     bool prefer_host_memory;
@@ -704,6 +706,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
@@ -1205,6 +1208,15 @@ struct vk_op_argsort_push_constants {
     uint32_t inner_end;
 };
 
+struct vk_op_topk_push_constants {
+    uint32_t orig_ncols;
+    uint32_t ncols_input;
+    uint32_t ncols_output;
+    uint32_t nrows;
+    uint32_t first_pass;
+    uint32_t last_pass;
+};
+
 struct vk_op_im2col_push_constants {
     uint64_t dst_addr;
     uint32_t batch_offset; uint32_t offset_delta;
@@ -3965,6 +3977,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }
 
+    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
+        const uint32_t BLOCK_SIZE = 1u << i;
+        const uint32_t NCOLS_PADDED_LOG2 = i;
+        if (i <= device->max_workgroup_size_log2) {
+            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
+                                  sizeof(int) * device->subgroup_size +
+                                  2 * sizeof(int) +
+                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
+                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
+            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+            }
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4336,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
         device->subgroup_size = subgroup_props.subgroupSize;
+        device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
         if (sm_builtins) {
             device->shader_core_count = sm_props.shaderSMCount;
@@ -10143,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }
 }
 
+static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
+    uint32_t k = dst->ne[0];
+
+    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+
+    // Reserve space for an ivec2 per element, double buffered
+    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
+    const size_t x_sz = dbl_buf_size * 2;
+    uint32_t dbl_buf_index = 0;
+
+    if (ctx->prealloc_size_x < x_sz) {
+        ctx->prealloc_size_x = x_sz;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+    if (ctx->prealloc_x_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    std::array<uint32_t, 3> elements;
+    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    uint32_t num_elements = ncols;
+
+    // Each iteration reduces a workgroup's worth of elements down to the K
+    // largest elements. Repeat until we have the top K elements.
+    // Need to do at least one iteration to write out the results.
+    bool done_one_iter = false;
+    while (num_elements > k || !done_one_iter) {
+        done_one_iter = true;
+
+        // Prefer going as small as num_topk_pipelines - 3 for perf reasons,
+        // but if K is larger then we need a larger workgroup.
+        uint32_t max_pipeline = num_topk_pipelines - 3;
+        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
+        // require a full subgroup
+        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
+
+        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
+        pipeline_idx = std::min(pipeline_idx, max_pipeline);
+        pipeline_idx = std::max(pipeline_idx, min_pipeline);
+
+        if (num_elements > (1u << pipeline_idx)) {
+            // If we could finish on this loop iteration (i.e. a single workgroup)
+            // then do so. It's better than the overhead of another pass.
+            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
+                if (num_elements <= (1u << i)) {
+                    pipeline_idx = i;
+                    break;
+                }
+            }
+        }
+
+        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        // If the device doesn't support a pipeline this large, use a smaller one
+        while (!pipeline) {
+            pipeline_idx--;
+            GGML_ASSERT(pipeline_idx >= min_pipeline);
+            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        }
+
+        vk_op_topk_push_constants pc2 = pc;
+        pc2.ncols_input = num_elements;
+
+        // Number of elements remaining after this pass
+        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
+
+        vk_subbuffer src_buf;
+        vk_subbuffer dst_buf;
+
+        if (num_elements == ncols) {
+            pc2.first_pass = 1;
+            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+        } else {
+            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
+        }
+        if (num_dst_elements == k) {
+            pc2.last_pass = 1;
+            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+        } else {
+            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
+        }
+
+        elements[0] = num_elements;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
+        num_elements = num_dst_elements;
+        dbl_buf_index ^= 1;
+        if (num_elements > k) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    ctx->prealloc_x_need_sync = true;
+}
+
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@@ -11755,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
         }
 
+        break;
+    case GGML_OP_TOP_K:
+        ggml_vk_topk(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_SUM:
         ggml_vk_sum(ctx, compute_ctx, src0, node);
@@ -13787,6 +13919,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return op->ne[0] <= (1 << device->max_workgroup_size_log2);
             }
         }
+        case GGML_OP_TOP_K:
+        {
+            if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
+                return false;
+            }
+            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+            auto device = ggml_vk_get_device(ctx->device);
+            // We could potentially support larger inputs by using argsort to
+            // sort the whole thing. Not clear if this is needed.
+            uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
+            if (min_pipeline >= num_topk_pipelines ||
+                !device->pipeline_topk_f32[min_pipeline]) {
+                return false;
+            }
+        }
+        return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
@@ -14459,6 +14607,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ARGSORT) {
         tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_TOP_K) {
+        tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
     } else if (tensor->op == GGML_OP_SUM) {
        tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_SUM_ROWS) {
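The per-pass shrink in ggml_vk_topk above is num_dst_elements = (num_elements / wg) * k + min(k, num_elements % wg): each full workgroup of wg inputs emits its top K, and a partial trailing group contributes at most K more. A standalone sketch, with assumed sizes, of how quickly the candidate set collapses (the real code also re-picks the workgroup size each pass):

    // Sketch: print the candidate count after each top-k reduction pass.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        uint32_t n = 30000, k = 8, wg = 256;  // illustrative sizes only
        while (n > k) {
            const uint32_t out = (n / wg) * k + std::min(k, n % wg);
            std::printf("%u -> %u\n", n, out);  // 30000 -> 944 -> 32 -> 8
            n = out;
        }
        return 0;
    }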
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+#version 450
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.glsl"
+
+layout(constant_id = 0) const int BLOCK_SIZE = 1024;
+layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// Input can either be the source (A) or intermediate values (S).
+// Similarly, output can be either the destination (D) or intermediate values (T).
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 0) readonly buffer S {ivec2 data_s[];};
+layout (binding = 1) writeonly buffer D {int data_d[];};
+layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
+
+layout (push_constant) uniform parameter {
+    uint orig_ncols;
+    uint ncols_input;
+    uint ncols_output;
+    uint nrows;
+    uint first_pass;
+    uint last_pass;
+} p;
+
+// pairs of (gid, value)
+shared ivec2 dst_row[BLOCK_SIZE];
+
+void topk(bool needs_bounds_check, const uint row) {
+    const int col = int(gl_LocalInvocationID.x);
+
+    // initialize indices
+    if (gl_GlobalInvocationID.x < p.ncols_input) {
+        if (p.first_pass != 0) {
+            const uint row_offset = row * p.ncols_input;
+            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
+        } else {
+            const uint row_offset = row * p.orig_ncols;
+            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
+        }
+    } else {
+        dst_row[col] = ivec2(p.orig_ncols, 0);
+    }
+    barrier();
+
+    if (p.ncols_output == 1) {
+        // Fast path for a single output - just do a max reduction
+        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
+            if (col < s) {
+                ivec2 a = dst_row[col];
+                ivec2 b = dst_row[col + s];
+                if (a.x >= p.orig_ncols ||
+                    (b.x < p.orig_ncols && b.y > a.y)) {
+                    dst_row[col] = b;
+                }
+            }
+            barrier();
+        }
+    } else {
+        // bitonic sort on this group of elements
+        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
+        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
+            uint num_inner_loop_iters = outer_idx + 1;
+            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
+                const int ixj = int(col ^ j);
+
+                int idx_0 = (col & k) == 0 ? col : ixj;
+                int idx_1 = (col & k) == 0 ? ixj : col;
+
+                ivec2 sh_idx_0 = dst_row[idx_0];
+                ivec2 sh_idx_1 = dst_row[idx_1];
+                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
+                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
+
+                if ((idx_0_oob ||
+                     (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
+                    dst_row[idx_0] = sh_idx_1;
+                    dst_row[idx_1] = sh_idx_0;
+                }
+
+                barrier();
+            }
+        }
+    }
+
+    if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+        if (p.last_pass != 0) {
+            const uint row_offset = row * p.ncols_output;
+            data_d[row_offset + col] = dst_row[col].x;
+        } else {
+            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
+            data_t[row_offset + col] = dst_row[col];
+        }
+    }
+}
+
+void main() {
+    // Fast path for fully occupied workgroups
+    if ((p.ncols_input % BLOCK_SIZE) == 0) {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            topk(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    } else {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            topk(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    }
+}
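A note on the packing the shader uses: each candidate is one ivec2 whose .y holds the value's raw IEEE-754 bits via floatBitsToInt, and the bitonic path converts back with intBitsToFloat before comparing. A C++20 sketch of the same round-trip (the struct and helper names are hypothetical, for illustration only):

    #include <bit>
    #include <cstdint>

    struct Candidate {
        uint32_t idx;   // original column; idx >= orig_ncols marks padding
        int32_t  bits;  // IEEE-754 bits of the value, like floatBitsToInt
    };

    inline Candidate pack(uint32_t idx, float v) {
        return { idx, std::bit_cast<int32_t>(v) };  // ~ floatBitsToInt(v)
    }

    inline bool value_less(Candidate a, Candidate b) {
        // Compare as floats, mirroring intBitsToFloat in the sort path.
        return std::bit_cast<float>(a.bits) < std::bit_cast<float>(b.bits);
    }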
