@@ -409,6 +409,7 @@ enum shader_reduction_mode {
 // argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t num_topk_moe_pipelines = 10;
+static constexpr uint32_t num_topk_pipelines = 11;

 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@@ -515,6 +516,7 @@ struct vk_device_struct {
     bool single_queue;
     bool support_async;
     uint32_t subgroup_size;
+    uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
     bool uma;
     bool prefer_host_memory;
@@ -704,6 +706,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
@@ -1205,6 +1208,15 @@ struct vk_op_argsort_push_constants {
     uint32_t inner_end;
 };

+struct vk_op_topk_push_constants {
+    uint32_t orig_ncols;
+    uint32_t ncols_input;
+    uint32_t ncols_output;
+    uint32_t nrows;
+    uint32_t first_pass;
+    uint32_t last_pass;
+};
+
 struct vk_op_im2col_push_constants {
     uint64_t dst_addr;
     uint32_t batch_offset; uint32_t offset_delta;
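
Note: the meaning of these fields is only implied by ggml_vk_topk further down, so the following standalone sketch of how they would be filled for a two-pass reduction is an inference, not part of the patch. It assumes one row of 4096 values reduced to its top 8 with a 256-invocation first pass:

#include <cstdint>
#include <cstdio>

// Mirrors vk_op_topk_push_constants; the comments are inferred semantics.
struct topk_push {
    uint32_t orig_ncols;   // row length of the original src0 tensor
    uint32_t ncols_input;  // number of elements fed into this pass
    uint32_t ncols_output; // K, the count each workgroup keeps
    uint32_t nrows;
    uint32_t first_pass;   // input is raw f32, not (value, index) pairs
    uint32_t last_pass;    // output goes to dst, not the scratch buffer
};

int main() {
    topk_push pass1 { 4096, 4096, 8, 1, 1, 0 }; // 4096/256 = 16 workgroups keep 8 each
    topk_push pass2 { 4096,  128, 8, 1, 0, 1 }; // 16*8 = 128 survivors -> final 8
    printf("pass1 in=%u, pass2 in=%u out=%u\n",
           pass1.ncols_input, pass2.ncols_input, pass2.ncols_output);
}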
@@ -3965,6 +3977,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }

+    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
+        const uint32_t BLOCK_SIZE = 1u << i;
+        const uint32_t NCOLS_PADDED_LOG2 = i;
+        if (i <= device->max_workgroup_size_log2) {
+            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
+                                  sizeof(int) * device->subgroup_size +
+                                  2 * sizeof(int) +
+                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
+                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
+            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+            }
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
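
Note: the nary-search variant needs about two ints of shared memory per invocation plus a few per-subgroup words, so even the largest workgroup fits comfortably within common limits. A standalone back-of-the-envelope check (the subgroup size and shared-memory limit below are example values, not queried from a device):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t subgroup_size = 32;    // assumption
    const size_t   max_shmem     = 32768; // assumption; the real code reads the device limit
    for (uint32_t i = 0; i < 11; ++i) {   // 11 == num_topk_pipelines
        const uint32_t block = 1u << i;
        // same formula as nary_shmem in the loop above
        const size_t shmem = 2 * sizeof(int) * block +
                             sizeof(int) * subgroup_size +
                             2 * sizeof(int) +
                             (block / subgroup_size) * sizeof(int);
        printf("block=%4u shmem=%5zu fits=%s\n", block, shmem,
               shmem <= max_shmem ? "yes" : "no");
    }
}

At block = 1024 this comes to 8456 bytes, well under a typical 32 KiB limit, so the argsort fallback mainly covers devices missing the required subgroup features.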
@@ -4336,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);

     device->subgroup_size = subgroup_props.subgroupSize;
+    device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
     device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
     if (sm_builtins) {
         device->shader_core_count = sm_props.shaderSMCount;
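
Aside: devices report power-of-two subgroup sizes in practice, so the float log2 round-trip above is exact. The integer identity it relies on can be sanity-checked with a sketch like this (C++20, not part of the patch):

#include <bit>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
    // For powers of two, counting trailing zeros equals log2.
    for (uint32_t sz : {4u, 8u, 16u, 32u, 64u, 128u}) {
        assert(uint32_t(std::countr_zero(sz)) == uint32_t(log2f(float(sz))));
    }
}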
@@ -10143,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }
 }

+static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
+    uint32_t k = dst->ne[0];
+
+    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+
+    // Reserve space for an ivec2 per element, double buffered
+    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
+    const size_t x_sz = dbl_buf_size * 2;
+    uint32_t dbl_buf_index = 0;
+
+    if (ctx->prealloc_size_x < x_sz) {
+        ctx->prealloc_size_x = x_sz;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+    if (ctx->prealloc_x_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    std::array<uint32_t, 3> elements;
+    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    uint32_t num_elements = ncols;
+
+    // Each iteration reduces each workgroup's worth of elements down to its
+    // K largest. Repeat until only the top K elements remain.
+    // At least one iteration is needed to write out the results.
+    bool done_one_iter = false;
+    while (num_elements > k || !done_one_iter) {
+        done_one_iter = true;
+
+        // Prefer capping the workgroup at 1 << (num_topk_pipelines - 3)
+        // invocations for perf reasons. But a larger K requires a larger workgroup.
+        uint32_t max_pipeline = num_topk_pipelines - 3;
+        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
+        // require at least a full subgroup
+        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
+
+        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
+        pipeline_idx = std::min(pipeline_idx, max_pipeline);
+        pipeline_idx = std::max(pipeline_idx, min_pipeline);
+
+        if (num_elements > (1u << pipeline_idx)) {
+            // If we could finish on this loop iteration (i.e. in a single
+            // workgroup), then do so. It's better than the overhead of another pass.
+            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
+                if (num_elements <= (1u << i)) {
+                    pipeline_idx = i;
+                    break;
+                }
+            }
+        }
+
+        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        // If the device doesn't support a pipeline this large, fall back to a smaller one
+        while (!pipeline) {
+            pipeline_idx--;
+            GGML_ASSERT(pipeline_idx >= min_pipeline);
+            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        }
+
+        vk_op_topk_push_constants pc2 = pc;
+        pc2.ncols_input = num_elements;
+
+        // Number of elements remaining after this pass
+        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
+
+        vk_subbuffer src_buf;
+        vk_subbuffer dst_buf;
+
+        if (num_elements == ncols) {
+            pc2.first_pass = 1;
+            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+        } else {
+            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
+        }
+        if (num_dst_elements == k) {
+            pc2.last_pass = 1;
+            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+        } else {
+            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
+        }
+
+        elements[0] = num_elements;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
+        num_elements = num_dst_elements;
+        dbl_buf_index ^= 1;
+        if (num_elements > k) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    ctx->prealloc_x_need_sync = true;
+}
+
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
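
To make the pass behavior of ggml_vk_topk concrete, here is a standalone host-side simulation of the same pipeline-selection and reduction arithmetic (assumptions: 32-wide subgroups and every pipeline slot populated; the shape ncols = 100000, k = 8 is just an example):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t num_topk_pipelines = 11;
    const uint32_t subgroup_size_log2 = 5;  // assumption: subgroup size 32
    uint32_t num_elements = 100000, k = 8;  // example shape

    bool done_one_iter = false;
    while (num_elements > k || !done_one_iter) {
        done_one_iter = true;
        const uint32_t max_pipeline = num_topk_pipelines - 3;
        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
        min_pipeline = std::max(min_pipeline, subgroup_size_log2);

        uint32_t idx = (uint32_t)ceilf(log2f(float(num_elements)));
        idx = std::min(idx, max_pipeline);
        idx = std::max(idx, min_pipeline);
        if (num_elements > (1u << idx)) {
            // take a bigger workgroup if it finishes in one dispatch
            for (uint32_t i = idx; i < num_topk_pipelines; ++i) {
                if (num_elements <= (1u << i)) { idx = i; break; }
            }
        }
        const uint32_t wg  = 1u << idx;
        const uint32_t out = (num_elements / wg) * k + std::min(k, num_elements % wg);
        printf("wg=%4u: %6u -> %u elements\n", wg, num_elements, out);
        num_elements = out;
    }
}

This prints three dispatches, 100000 -> 3128 -> 104 -> 8: each pass keeps K elements per workgroup until a single workgroup can finish.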
@@ -11755,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
         }

+        break;
+    case GGML_OP_TOP_K:
+        ggml_vk_topk(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_SUM:
         ggml_vk_sum(ctx, compute_ctx, src0, node);
@@ -13787,6 +13919,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return op->ne[0] <= (1 << device->max_workgroup_size_log2);
             }
         }
+    case GGML_OP_TOP_K:
+        {
+            if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
+                return false;
+            }
+            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+            auto device = ggml_vk_get_device(ctx->device);
+            // We could potentially support a larger K by using argsort to
+            // sort the whole row. It's not clear that's needed.
+            uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
+            if (min_pipeline >= num_topk_pipelines ||
+                !device->pipeline_topk_f32[min_pipeline]) {
+                return false;
+            }
+        }
+        return true;
     case GGML_OP_UPSCALE:
     case GGML_OP_ACC:
     case GGML_OP_CONCAT:
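
For reference, the size check in the GGML_OP_TOP_K case works out as follows (standalone sketch; it assumes every pipeline slot was created, which the pipeline_topk_f32 lookup additionally verifies per device):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t num_topk_pipelines = 11;
    for (uint32_t k : {1u, 8u, 512u, 1023u, 1024u}) {
        // op->ne[0] is K for GGML_OP_TOP_K
        const uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
        printf("k=%4u -> pipeline %2u: %s\n", k, min_pipeline,
               min_pipeline < num_topk_pipelines ? "ok" : "rejected");
    }
}

So K values up to 1023 pass the bound; larger K is rejected unless the argsort-the-whole-row fallback mentioned in the comment is ever added.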
@@ -14459,6 +14607,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ARGSORT) {
         tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_TOP_K) {
+        tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
     } else if (tensor->op == GGML_OP_SUM) {
         tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {