diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 58fee4bfd1296..d1abb3cef0ec4 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -126,6 +126,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_ELU, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, @@ -649,6 +650,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); @@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: return false; @@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_ELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 86fdf1c18cfb6..971f5054bce2a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -782,6 +782,14 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_elu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); +} + kernel void kernel_sqr( device const float * src0, device float * dst, @@ -2137,20 +2145,34 @@ kernel void kernel_im2col( uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t OW = tgpg[2]; - const int32_t offset_dst = - (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); +// const int64_t N = ntg[0]; + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; + + const int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; + + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; + + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { pdst[offset_dst] = 0.0f; } else { - const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw; + pdst[offset_dst] = x[offset_src]; } } @@ -2201,25 +2223,25 @@ kernel void kernel_im2col_ext( uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] - const int32_t d = tgpig[0] / CHW; - const int32_t chw = tgpig[0] % CHW; - const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int32_t HW = tgpig[0] % KHW; + const int64_t d = tgpig[0] / CHW; + const int64_t chw = tgpig[0] % CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; - const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; if (tpitg_0 >= N) { return; } - const int32_t tpitg_1 = HW / KW; - const int32_t tpitg_2 = HW % KW; + const int64_t tpitg_1 = HW / KW; + const int64_t tpitg_2 = HW % KW; - const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; - const int32_t offset_dst = + const int64_t offset_dst = (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); @@ -2228,7 +2250,7 @@ kernel void kernel_im2col_ext( if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { pdst[offset_dst] = 0.0f; } else { - const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; pdst[offset_dst] = x[offset_src + iih * IW + iiw]; } } diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6ddb71ab133bf..e9bd2dbb05e2c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -9d0708e863f3aa2fc1eb0b75d433303c30bd0dbc +2884dd72fea8922910fe53387c3d17ab928d3a8e