diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index d0fb3bccad225..18f095b896010 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
 
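Note (commentary, not part of the patch): ggml_node_get_use_count() exposes the per-node use counts that ggml_node_has_n_uses() already relied on, so a backend can reason about individual consumers instead of only an exact total. A minimal sketch of the kind of check the Vulkan fusion code below performs with it; the helper name is illustrative and assumes the internal ggml-impl.h API is available:

    // sketch: verify a softmax node feeds exactly two consumers before fusing it
    #include "ggml-impl.h"

    static bool softmax_feeds_two_consumers(const struct ggml_cgraph * cgraph, int node_idx) {
        const struct ggml_tensor * node = cgraph->nodes[node_idx];
        if (node->op != GGML_OP_SOFT_MAX) {
            return false;
        }
        // any other use count means a node outside the fused pattern also reads the
        // result, so replacing the pattern with a fused op could change behavior
        return ggml_node_get_use_count(cgraph, node_idx) == 2;
    }
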
"topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (ctx->num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); + bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + return ctx->device->pipeline_topk_moe[idx][with_norm]; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -9589,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } +static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + + bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; + ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; + ggml_tensor * ids = cgraph->nodes[node_idx + 3]; + + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + const int n_expert_used = weights->ne[1]; + + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context; + ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_logits = nullptr; + size_t logits_buf_offset = 0; + vk_buffer d_weights = nullptr; + size_t weights_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool logits_uma = false; + bool weights_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset); + ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + logits_uma = d_logits != nullptr; + weights_uma = d_weights != nullptr; + ids_uma = d_ids != nullptr; + } + + if (!logits_uma) { + d_logits = logits_buf_ctx->dev_buffer; + logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs; + GGML_ASSERT(d_logits != nullptr); + } + if (!weights_uma) { + d_weights = weights_buf_ctx->dev_buffer; + weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs; + GGML_ASSERT(d_weights != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + + vk_op_topk_moe_push_constants pc; + pc.n_rows = 
@@ -9589,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
 }
 
+static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+
+    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+    ggml_tensor * logits  = cgraph->nodes[node_idx + 0]->src[0];
+    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+    ggml_tensor * ids     = cgraph->nodes[node_idx + 3];
+
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows = logits->ne[1];
+    const int n_expert_used = weights->ne[1];
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * logits_buf_ctx  = (ggml_backend_vk_buffer_context *)logits->buffer->context;
+    ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
+    ggml_backend_vk_buffer_context * ids_buf_ctx     = (ggml_backend_vk_buffer_context *)ids->buffer->context;
+
+    vk_buffer d_logits = nullptr;
+    size_t logits_buf_offset = 0;
+    vk_buffer d_weights = nullptr;
+    size_t weights_buf_offset = 0;
+    vk_buffer d_ids = nullptr;
+    size_t ids_buf_offset = 0;
+
+    bool logits_uma = false;
+    bool weights_uma = false;
+    bool ids_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
+        ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
+        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
+        logits_uma = d_logits != nullptr;
+        weights_uma = d_weights != nullptr;
+        ids_uma = d_ids != nullptr;
+    }
+
+    if (!logits_uma) {
+        d_logits = logits_buf_ctx->dev_buffer;
+        logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
+        GGML_ASSERT(d_logits != nullptr);
+    }
+    if (!weights_uma) {
+        d_weights = weights_buf_ctx->dev_buffer;
+        weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
+        GGML_ASSERT(d_weights != nullptr);
+    }
+    if (!ids_uma) {
+        d_ids = ids_buf_ctx->dev_buffer;
+        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
+        GGML_ASSERT(d_ids != nullptr);
+    }
+
+    vk_op_topk_moe_push_constants pc;
+    pc.n_rows = n_rows;
+    pc.n_expert_used = n_expert_used;
+
+    GGML_ASSERT(n_expert_used <= n_experts);
+
+    const uint32_t rows_per_block = 4;
+    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
+            ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
+            ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
+        }, pc, elements);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
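Note (commentary, not part of the patch): per row, the fused dispatch replaces softmax, top-k selection and, in the norm variant, re-normalization of the selected weights. A host-side reference of that math can be useful for sanity-checking the Vulkan output against the unfused graph; the function name and the contiguous row layout are assumptions:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // one row: logits[n_experts] -> weights[n_expert_used], ids[n_expert_used]
    static void topk_moe_reference(const float * logits, float * weights, int32_t * ids,
                                   int n_experts, int n_expert_used, bool with_norm) {
        // softmax over the row
        const float max_val = *std::max_element(logits, logits + n_experts);
        std::vector<float> p(n_experts);
        float sum = 0.0f;
        for (int i = 0; i < n_experts; ++i) {
            p[i] = std::exp(logits[i] - max_val);
            sum += p[i];
        }
        for (float & v : p) {
            v /= sum;
        }

        // indices of the n_expert_used largest probabilities (ties: lower index wins,
        // matching the "expert < max_expert" tie-break in topk_moe.comp)
        std::vector<int32_t> order(n_experts);
        std::iota(order.begin(), order.end(), 0);
        std::stable_sort(order.begin(), order.end(),
                         [&](int32_t a, int32_t b) { return p[a] > p[b]; });

        float wt_sum = 0.0f;
        for (int k = 0; k < n_expert_used; ++k) {
            ids[k]     = order[k];
            weights[k] = p[order[k]];
            wt_sum    += weights[k];
        }

        // the with_norm variant re-normalizes the selected weights to sum to 1
        if (with_norm) {
            for (int k = 0; k < n_expert_used; ++k) {
                weights[k] /= wt_sum;
            }
        }
    }
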
@@ -11174,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit) {
                 ctx->unsynced_nodes_read.clear();
                 ggml_vk_sync_buffers(ctx, compute_ctx);
             }
-            // Add the last fused node and all fused source nodes to the unsynchronized list.
-            const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
-            ctx->unsynced_nodes_written.push_back(last_node);
+            // Add all fused nodes to the unsynchronized lists.
             for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
                 const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+                // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
+                ctx->unsynced_nodes_written.push_back(cur_node);
                 for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
                     if (!cur_node->src[j]) {
                         continue;
                     }
@@ -11345,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit) {
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        }
 
         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ -12141,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
     return true;
 }
 
+static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                      int node_idx, bool with_norm) {
+
+    if (with_norm) {
+        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
+                return false;
+            }
+        }
+    } else {
+        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe.size(); ++i) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
+                return false;
+            }
+        }
+    }
+
+    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
+    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+
+    const float * op_params = (const float *)softmax->op_params;
+
+    float scale    = op_params[0];
+    float max_bias = op_params[1];
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
+        return false;
+    }
+
+    // Check that the nodes don't have any unexpected uses
+    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
+    const ggml_tensor * argsort  = cgraph->nodes[node_idx + 2];
+    const ggml_tensor * view     = cgraph->nodes[node_idx + 3];
+    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
+    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
+    const ggml_tensor * div      = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
+    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
+
+    // softmax is used by reshape and argsort
+    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
+        reshape1->src[0] != softmax ||
+        argsort->src[0] != softmax) {
+        return false;
+    }
+    // reshape is used by get_rows
+    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
+        get_rows->src[0] != reshape1) {
+        return false;
+    }
+    // argsort is used by view
+    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
+        view->src[0] != argsort) {
+        return false;
+    }
+    // view is written (via argsort), we can skip checking it
+
+    if (with_norm) {
+        // get_rows is used by reshape
+        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
+            reshape5->src[0] != get_rows) {
+            return false;
+        }
+
+        // reshape is used by sum_rows and div
+        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
+            sum_rows->src[0] != reshape5 ||
+            div->src[0] != reshape5) {
+            return false;
+        }
+
+        // sum_rows is used by div
+        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
+            div->src[1] != sum_rows) {
+            return false;
+        }
+
+        // div/reshape are written
+        if (reshape8->src[0] != div) {
+            return false;
+        }
+    }
+
+    if (!ctx->device->subgroup_arithmetic ||
+        !ctx->device->subgroup_shuffle ||
+        !ctx->device->subgroup_require_full_support ||
+        ctx->device->disable_fusion) {
+        return false;
+    }
+
+    return true;
+}
+
 static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
 
     const ggml_tensor *first_node = cgraph->nodes[node_idx];
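Note (commentary, not part of the patch): topk_moe_norm matches the expert-selection subgraph that llama.cpp's MoE layers build. Roughly, and with the helper name, shapes and node order only assumed here (the final cgraph order comes from ggml_build_forward_expand), the pattern corresponds to:

    #include "ggml.h"

    // sketch of the 9-op SOFT_MAX .. RESHAPE pattern; name is illustrative
    static ggml_tensor * build_topk_moe_norm(ggml_context * ctx, ggml_tensor * logits,
                                             int64_t n_expert, int n_expert_used, int64_t n_tokens) {
        ggml_tensor * probs    = ggml_soft_max(ctx, logits);                       // SOFT_MAX
        ggml_tensor * selected = ggml_top_k(ctx, probs, n_expert_used);            // ARGSORT + VIEW
        ggml_tensor * weights  = ggml_get_rows(ctx,
                ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);     // RESHAPE + GET_ROWS
        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);          // RESHAPE
        ggml_tensor * sums = ggml_sum_rows(ctx, weights);                          // SUM_ROWS
        weights = ggml_div(ctx, weights, sums);                                    // DIV
        return ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);          // RESHAPE
    }

The shorter topk_moe pattern is the same subgraph without the RESHAPE/SUM_ROWS/DIV/RESHAPE tail, i.e. when the selected weights are not re-normalized.
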
@@ -12216,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
+                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
+                ctx->num_additional_fused_ops = topk_moe.size() - 1;
             }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12313,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
+                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
+                ctx->num_additional_fused_ops = topk_moe.size() - 1;
             }
         }
 
@@ -12320,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i + ctx->num_additional_fused_ops == last_node) ||
+                      (i + ctx->num_additional_fused_ops >= last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {
@@ -12444,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) {
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;
 
+        // Avoid reordering topk_moe_norm
+        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
+            bool is_topk_moe_norm = true;
+            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
+                    is_topk_moe_norm = false;
+                }
+            }
+            if (is_topk_moe_norm) {
+                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+                    new_order.push_back(graph->nodes[first_unused + j]);
+                    used[first_unused + j] = true;
+                }
+                while (first_unused < graph->n_nodes && used[first_unused]) {
+                    first_unused++;
+                }
+                continue;
+            }
+        }
         // First, grab the next unused node.
         current_set.push_back(first_unused);
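Note (commentary, not part of the patch): both ggml_vk_can_fuse_topk_moe() and the ggml_vk_graph_optimize() guard above depend on the pattern's nodes sitting contiguously in graph order, which is why the optimizer emits the whole run unchanged instead of interleaving it with other nodes. The window test they share could be factored into a small helper along these lines (name and template form are mine, not the patch's):

    #include "ggml-impl.h"

    #include <array>
    #include <cstddef>

    // true if the ops at graph->nodes[idx .. idx+N-1] match the given pattern
    template <size_t N>
    static bool ops_match_at(const struct ggml_cgraph * graph, int idx,
                             const std::array<enum ggml_op, N> & pattern) {
        if (idx + (int)N > graph->n_nodes) {
            return false;
        }
        for (size_t j = 0; j < N; ++j) {
            if (graph->nodes[idx + j]->op != pattern[j]) {
                return false;
            }
        }
        return true;
    }
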
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
new file mode 100644
index 0000000000000..9e56d5f8a3cc1
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -0,0 +1,139 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.glsl"
+
+layout (push_constant) uniform parameter
+{
+    uint n_rows;
+    uint n_expert_used;
+};
+
+layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
+
+layout(constant_id = 0) const uint WARP_SIZE = 32;
+layout(constant_id = 1) const uint n_experts = 512;
+layout(constant_id = 2) const bool with_norm = true;
+
+const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+layout (binding = 0, std430) readonly buffer Logits {float logits[];};
+layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
+
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    const uint logits_offset = n_experts * row;
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+
+    float logits_r[experts_per_thread];
+
+    const float INFINITY = 1.0 / 0.0;
+
+    [[unroll]]
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert = i + gl_LocalInvocationID.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+    [[unroll]]
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val = max(val, max_val);
+    }
+
+    max_val = subgroupMax(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i] = exp(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = subgroupAdd(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    // at this point, each thread holds a portion of softmax,
+    // we do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    float output_weights[experts_per_thread];
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val = wt[0];
+        uint max_expert = gl_LocalInvocationID.x;
+
+        [[unroll]]
+        for (int i = 1; i < experts_per_thread; i++) {
+            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                max_val = wt[i];
+                max_expert = expert;
+            }
+        }
+
+        [[unroll]]
+        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val = subgroupShuffleXor(max_val, mask);
+            const uint expert = subgroupShuffleXor(max_expert, mask);
+            if (val > max_val || (val == max_val && expert < max_expert)) {
+                max_val = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            ids[ids_offset + k] = max_expert;
+            if (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if (with_norm) {
+        wt_sum = subgroupAdd(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        [[unroll]]
+        for (uint i = 0; i < experts_per_thread; ++i) {
+            output_weights[i] *= inv_sum;
+        }
+    }
+
+    [[unroll]]
+    for (uint i = 0; i < experts_per_thread; ++i) {
+        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+        if (idx < n_expert_used) {
+            weights[weights_offset + idx] = output_weights[i];
+        }
+    }
+}
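Note (commentary, not part of the patch): topk_moe.comp stripes the experts across the subgroup, so lane t keeps experts t, t + WARP_SIZE, t + 2*WARP_SIZE, ... in registers; expert e therefore lives in slot e / WARP_SIZE of lane e % WARP_SIZE, which is what the wt[max_expert / WARP_SIZE] and (max_expert & (WARP_SIZE - 1)) indexing relies on. The selected weight of iteration k is parked the same way, in slot k / WARP_SIZE of lane k % WARP_SIZE, before the final strided write to the weights buffer. A tiny illustration with assumed sizes:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t WARP_SIZE = 32;   // assumed subgroup size
        const uint32_t n_experts = 256;  // specialization constant, power of two
        const uint32_t experts_per_thread = n_experts > WARP_SIZE ? n_experts / WARP_SIZE : 1;

        printf("experts_per_thread = %u\n", experts_per_thread);
        for (uint32_t e : {0u, 1u, 31u, 32u, 200u, 255u}) {
            printf("expert %3u -> lane %2u, slot %u\n", e, e % WARP_SIZE, e / WARP_SIZE);
        }
        return 0;
    }
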
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 1d04a812a038a..49bf6c764f726 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -920,6 +920,8 @@ void process_shaders() {
 
     string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
 
+    string_to_spv("topk_moe_f32", "topk_moe.comp", {});
+
     for (auto &c : compiles) {
         c.wait();
     }