@@ -497,6 +497,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
497497 GGML_METAL_KERNEL_TYPE_SIN,
498498 GGML_METAL_KERNEL_TYPE_COS,
499499 GGML_METAL_KERNEL_TYPE_NEG,
500+ GGML_METAL_KERNEL_TYPE_REGLU,
501+ GGML_METAL_KERNEL_TYPE_GEGLU,
502+ GGML_METAL_KERNEL_TYPE_SWIGLU,
500503 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501504 GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502505 GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1453,6 +1456,9 @@ @implementation GGMLMetalClass
14531456 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
14541457 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
14551458 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
1459+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REGLU, reglu, true );
1460+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true );
1461+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true );
14561462 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
14571463 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
14581464 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
@@ -1626,6 +1632,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16261632 default :
16271633 return false ;
16281634 }
1635+ case GGML_OP_GLU:
1636+ switch (ggml_get_glu_op (op)) {
1637+ case GGML_GLU_OP_REGLU:
1638+ case GGML_GLU_OP_GEGLU:
1639+ case GGML_GLU_OP_SWIGLU:
1640+ return ggml_is_contiguous_1 (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
1641+ default :
1642+ return false ;
1643+ }
16291644 case GGML_OP_NONE:
16301645 case GGML_OP_RESHAPE:
16311646 case GGML_OP_VIEW:
@@ -2343,6 +2358,43 @@ static bool ggml_metal_encode_node(
23432358 GGML_ABORT (" fatal error" );
23442359 }
23452360 } break ;
2361+ case GGML_OP_GLU:
2362+ {
2363+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
2364+
2365+ id <MTLComputePipelineState > pipeline = nil ;
2366+
2367+ switch (ggml_get_glu_op (node)) {
2368+ case GGML_GLU_OP_REGLU:
2369+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REGLU].pipeline ;
2370+ break ;
2371+ case GGML_GLU_OP_GEGLU:
2372+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GEGLU].pipeline ;
2373+ break ;
2374+ case GGML_GLU_OP_SWIGLU:
2375+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline ;
2376+ break ;
2377+ default :
2378+ GGML_ABORT (" fatal error" );
2379+ }
2380+
2381+ ggml_metal_kargs_glu args = {
2382+ /* .ne00 =*/ ne00,
2383+ /* .nb01 =*/ nb01,
2384+ /* .nb1 =*/ nb1,
2385+ };
2386+
2387+ [encoder setComputePipelineState: pipeline];
2388+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2389+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2390+ [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
2391+
2392+ const int64_t nrows = ggml_nrows (src0);
2393+
2394+ const int32_t nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00/2 );
2395+
2396+ [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
2397+ } break ;
23462398 case GGML_OP_SQR:
23472399 {
23482400 GGML_ASSERT (ggml_is_contiguous (src0));
@@ -2405,7 +2457,6 @@ static bool ggml_metal_encode_node(
24052457
24062458 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
24072459
2408-
24092460 ggml_metal_kargs_sum_rows args = {
24102461 /* .ne00 =*/ ne00,
24112462 /* .ne01 =*/ ne01,
0 commit comments