@@ -514,6 +514,9 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
514514 GGML_METAL_KERNEL_TYPE_SIN,
515515 GGML_METAL_KERNEL_TYPE_COS,
516516 GGML_METAL_KERNEL_TYPE_NEG,
517+ GGML_METAL_KERNEL_TYPE_REGLU,
518+ GGML_METAL_KERNEL_TYPE_GEGLU,
519+ GGML_METAL_KERNEL_TYPE_SWIGLU,
517520 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
518521 GGML_METAL_KERNEL_TYPE_MEAN,
519522 GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1478,6 +1481,9 @@ @implementation GGMLMetalClass
14781481 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
14791482 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
14801483 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
1484+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REGLU, reglu, true );
1485+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true );
1486+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true );
14811487 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
14821488 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, true );
14831489 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
@@ -1652,6 +1658,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16521658 default :
16531659 return false ;
16541660 }
1661+ case GGML_OP_GLU:
1662+ switch (ggml_get_glu_op (op)) {
1663+ case GGML_GLU_OP_REGLU:
1664+ case GGML_GLU_OP_GEGLU:
1665+ case GGML_GLU_OP_SWIGLU:
1666+ return ggml_is_contiguous_1 (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
1667+ default :
1668+ return false ;
1669+ }
16551670 case GGML_OP_NONE:
16561671 case GGML_OP_RESHAPE:
16571672 case GGML_OP_VIEW:
@@ -2370,6 +2385,43 @@ static bool ggml_metal_encode_node(
23702385 GGML_ABORT (" fatal error" );
23712386 }
23722387 } break ;
2388+ case GGML_OP_GLU:
2389+ {
2390+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
2391+
2392+ id <MTLComputePipelineState > pipeline = nil ;
2393+
2394+ switch (ggml_get_glu_op (node)) {
2395+ case GGML_GLU_OP_REGLU:
2396+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REGLU].pipeline ;
2397+ break ;
2398+ case GGML_GLU_OP_GEGLU:
2399+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GEGLU].pipeline ;
2400+ break ;
2401+ case GGML_GLU_OP_SWIGLU:
2402+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline ;
2403+ break ;
2404+ default :
2405+ GGML_ABORT (" fatal error" );
2406+ }
2407+
2408+ ggml_metal_kargs_glu args = {
2409+ /* .ne00 =*/ ne00,
2410+ /* .nb01 =*/ nb01,
2411+ /* .nb1 =*/ nb1,
2412+ };
2413+
2414+ [encoder setComputePipelineState: pipeline];
2415+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2416+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2417+ [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
2418+
2419+ const int64_t nrows = ggml_nrows (src0);
2420+
2421+ const int32_t nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00/2 );
2422+
2423+ [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
2424+ } break ;
23732425 case GGML_OP_SQR:
23742426 {
23752427 GGML_ASSERT (ggml_is_contiguous (src0));
0 commit comments