@@ -126,6 +126,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_ELU,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -649,6 +650,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,   gelu_quick_4,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,           silu,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,         silu_4,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,            elu,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,   soft_max_f16,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,   soft_max_f32,   has_simdgroup_reduction);
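GGML_METAL_ADD_KERNEL maps each enum entry to a Metal function by name (the second argument, "elu", which the macro resolves to kernel_elu in ggml-metal.metal) and caches the compiled pipeline. Worth noting: unlike SILU, this change registers no float4 variant for ELU. If one were added later, the device side would presumably mirror silu_4 — a hypothetical sketch, not part of this commit, with kernel_elu_4 an assumed name:

    // hypothetical: a float4-vectorized ELU, mirroring the existing *_4 kernels
    kernel void kernel_elu_4(
            device const float4 * src0,
            device       float4 * dst,
            uint tpig[[thread_position_in_grid]]) {
        const float4 x = src0[tpig];
        // select() picks exp(x)-1 where x <= 0 and x elsewhere, per component
        dst[tpig] = select(exp(x) - 1.0f, x, x > 0.0f);
    }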
@@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_ELU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_UNARY_OP_ELU:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
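The encode path above assumes a matching kernel_elu compute function in ggml-metal.metal (the .metal side of the commit is not shown in this excerpt). Because the op is gated on contiguous inputs and the encoder launches one thread per element, the device code is a plain per-element map; with ELU's conventional alpha = 1 that is ELU(x) = x for x > 0 and exp(x) - 1 otherwise. A minimal sketch following the pattern of the neighboring scalar unary kernels:

    // sketch of the assumed device-side kernel (ggml-metal.metal), alpha = 1
    kernel void kernel_elu(
            device const float * src0,
            device       float * dst,
            uint tpig[[thread_position_in_grid]]) {
        const float x = src0[tpig];
        dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
    }

Dispatching n threadgroups of a single thread each matches the other scalar unary ops in this file; it is simple but leaves occupancy on the table, which is what the *_4 float4 variants address for hotter kernels like SILU and GELU.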