@@ -126,6 +126,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
126126 GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
127127 GGML_METAL_KERNEL_TYPE_SILU,
128128 GGML_METAL_KERNEL_TYPE_SILU_4,
129+ GGML_METAL_KERNEL_TYPE_ELU,
129130 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
130131 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
131132 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -649,6 +650,7 @@ @implementation GGMLMetalClass
649650 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true );
650651 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU, silu, true );
651652 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true );
653+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ELU, elu, true );
652654 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
653655 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
654656 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
968970 case GGML_UNARY_OP_GELU:
969971 case GGML_UNARY_OP_GELU_QUICK:
970972 case GGML_UNARY_OP_SILU:
973+ case GGML_UNARY_OP_ELU:
971974 return ggml_is_contiguous (op->src [0 ]);
972975 default :
973976 return false ;
@@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(
15891592
15901593 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
15911594 } break ;
1595+ case GGML_UNARY_OP_ELU:
1596+ {
1597+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ELU].pipeline ;
1598+
1599+ [encoder setComputePipelineState: pipeline];
1600+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1601+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1602+
1603+ const int64_t n = ggml_nelements (dst);
1604+
1605+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1606+ } break ;
15921607 default :
15931608 {
15941609 GGML_LOG_WARN (" %s : node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments