@@ -665,7 +665,9 @@ struct vk_op_push_constants {
665665};
666666
667667struct vk_op_glu_push_constants {
668+ uint32_t N;
668669 uint32_t ne00;
670+ uint32_t ne20;
669671 uint32_t mode; // 0: default, 1: swapped, 2: split
670672};
671673
@@ -2761,8 +2763,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27612763#undef CREATE_UNARY
27622764
27632765#define CREATE_GLU(name) \
2764- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1 , 1, 1}, { device->subgroup_size }, 1); \
2765- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1 , 1, 1}, { device->subgroup_size }, 1);
2766+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512 , 1, 1}, {}, 1, true ); \
2767+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512 , 1, 1}, {}, 1, true );
27662768
27672769 CREATE_GLU(geglu)
27682770 CREATE_GLU(reglu)
@@ -6867,7 +6869,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
68676869 case GGML_OP_SOFT_MAX_BACK:
68686870 case GGML_OP_SUM_ROWS:
68696871 case GGML_OP_ARGMAX:
6870- case GGML_OP_GLU:
68716872 {
68726873 const uint32_t nr = ggml_nrows(src0);
68736874 if (nr > 262144) {
@@ -6952,6 +6953,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69526953 case GGML_OP_CONCAT:
69536954 case GGML_OP_UPSCALE:
69546955 case GGML_OP_UNARY:
6956+ case GGML_OP_GLU:
69556957 case GGML_OP_CONV_2D_DW:
69566958 {
69576959 uint32_t ne = ggml_nelements(dst);
@@ -7600,7 +7602,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
76007602
76017603 const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
76027604
7603- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
7605+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t) src0->ne[0], (uint32_t)dst ->ne[0], mode }, dryrun);
76047606}
76057607
76067608static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
0 commit comments