@@ -659,6 +659,11 @@ struct vk_op_push_constants {
659659 float param2;
660660};
661661
662+ struct vk_op_glu_push_constants {
663+ uint32_t ne00;
664+ uint32_t mode; // 0: default, 1: swapped, 2: split
665+ };
666+
662667struct vk_op_unary_push_constants {
663668 uint32_t ne;
664669 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27332738#undef CREATE_UNARY
27342739
27352740#define CREATE_GLU(name) \
2736- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2 , sizeof(vk_op_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1); \
2737- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2 , sizeof(vk_op_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);
2741+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3 , sizeof(vk_op_glu_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1); \
2742+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3 , sizeof(vk_op_glu_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);
27382743
27392744 CREATE_GLU(geglu)
27402745 CREATE_GLU(reglu)
@@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69476952 }
69486953 }
69496954
6950- if (op == GGML_OP_SOFT_MAX) {
6955+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU ) {
69516956        // Empty src1 is possible in soft_max and glu, but the shader needs a buffer
69526957 vk_subbuffer subbuf_y;
69536958 if (use_src1) {
@@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
75397544 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
75407545}
75417546
7542- static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7543- GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7547+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7548+ const bool swapped = (bool)dst->op_params[1];
7549+ const bool split = src1 != nullptr;
7550+
7551+ GGML_ASSERT(ggml_is_contiguous(src0));
7552+
7553+ if (!split) {
7554+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7555+ } else {
7556+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7557+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7558+ GGML_ASSERT(src0->type == src1->type);
7559+ }
75447560
7545- const uint32_t swapped = (uint32_t)dst->op_params[1] ;
7561+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0) ;
75467562
7547- ggml_vk_op_f32<vk_op_push_constants >(ctx, subctx, src0, nullptr , nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
7563+ ggml_vk_op_f32<vk_op_glu_push_constants >(ctx, subctx, src0, src1 , nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
75487564}
75497565
75507566static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
90039019 case GGML_GLU_OP_GEGLU:
90049020 case GGML_GLU_OP_REGLU:
90059021 case GGML_GLU_OP_SWIGLU:
9006- ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
9022+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
90079023 break;
90089024 default:
90099025 return false;
@@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1072510741 GGML_ABORT("fatal error");
1072610742 }
1072710743 } else if (tensor->op == GGML_OP_GLU) {
10728- tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10744+ if (src_clone[1] == nullptr) {
10745+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10746+ } else {
10747+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
10748+ }
1072910749 } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
1073010750 if (src1 == nullptr) {
1073110751 tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
0 commit comments