Skip to content

Commit 12b0ad9

Browse files
PABannier authored and
ggerganov committed
metal : add GGML_UNARY_OP_ELU kernel (ggml/1018)
1 parent 342397d commit 12b0ad9

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
126126
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
127127
GGML_METAL_KERNEL_TYPE_SILU,
128128
GGML_METAL_KERNEL_TYPE_SILU_4,
129+
GGML_METAL_KERNEL_TYPE_ELU,
129130
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
130131
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
131132
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -649,6 +650,7 @@ @implementation GGMLMetalClass
649650
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
650651
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
651652
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
653+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
652654
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
653655
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
654656
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
968970
case GGML_UNARY_OP_GELU:
969971
case GGML_UNARY_OP_GELU_QUICK:
970972
case GGML_UNARY_OP_SILU:
973+
case GGML_UNARY_OP_ELU:
971974
return ggml_is_contiguous(op->src[0]);
972975
default:
973976
return false;
@@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(
15891592

15901593
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
15911594
} break;
1595+
case GGML_UNARY_OP_ELU:
1596+
{
1597+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
1598+
1599+
[encoder setComputePipelineState:pipeline];
1600+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1601+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1602+
1603+
const int64_t n = ggml_nelements(dst);
1604+
1605+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1606+
} break;
15921607
default:
15931608
{
15941609
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -782,6 +782,14 @@ kernel void kernel_silu_4(
782782
dst[tpig] = x / (1.0f + exp(-x));
783783
}
784784

785+
// Element-wise ELU activation (alpha fixed at 1): dst[i] = x if x > 0, else exp(x) - 1.
//
// src0 - input buffer of floats, indexed by the thread's grid position
// dst  - output buffer of floats, written at the same index that is read from src0
// tpig - this thread's position in the dispatch grid; used directly as the element index
//
// NOTE(review): there is no bounds check on tpig, so the dispatch is presumably sized to
// exactly the number of elements — confirm against the host-side encoder.
kernel void kernel_elu(
786+
device const float * src0,
787+
device float * dst,
788+
uint tpig[[thread_position_in_grid]]) {
789+
// Bind the input element by reference; it is read in both the comparison and the result.
device const float & x = src0[tpig];
790+
// Positive inputs pass through unchanged; non-positive inputs map to exp(x) - 1.
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
791+
}
792+
785793
kernel void kernel_sqr(
786794
device const float * src0,
787795
device float * dst,

0 commit comments

Comments
 (0)