Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_ELU,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
Expand Down Expand Up @@ -649,6 +650,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
Expand Down Expand Up @@ -968,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
Expand Down Expand Up @@ -1589,6 +1592,18 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_ELU:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
Expand Down
60 changes: 41 additions & 19 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,14 @@ kernel void kernel_silu_4(
dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_elu(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}

kernel void kernel_sqr(
device const float * src0,
device float * dst,
Expand Down Expand Up @@ -2137,20 +2145,34 @@ kernel void kernel_im2col(
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];

const int32_t offset_dst =
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
// const int64_t N = ntg[0];
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];

const int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];

const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];

const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;

const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);

device T * pdst = (device T *) (dst);

if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
pdst[offset_dst] = x[offset_src];
}
}

Expand Down Expand Up @@ -2201,25 +2223,25 @@ kernel void kernel_im2col_ext(
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]

const int32_t d = tgpig[0] / CHW;
const int32_t chw = tgpig[0] % CHW;
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int32_t HW = tgpig[0] % KHW;
const int64_t d = tgpig[0] / CHW;
const int64_t chw = tgpig[0] % CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;

const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= N) {
return;
}

const int32_t tpitg_1 = HW / KW;
const int32_t tpitg_2 = HW % KW;
const int64_t tpitg_1 = HW / KW;
const int64_t tpitg_2 = HW % KW;

const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;

const int32_t offset_dst =
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);

Expand All @@ -2228,7 +2250,7 @@ kernel void kernel_im2col_ext(
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/sync-ggml.last
Original file line number Diff line number Diff line change
@@ -1 +1 @@
9d0708e863f3aa2fc1eb0b75d433303c30bd0dbc
2884dd72fea8922910fe53387c3d17ab928d3a8e
Loading