diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 9527973015245..1af0a02ee08d7 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -655,7 +655,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SIN:
         case GGML_OP_COS:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
-            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+            // this patch adds F16 kernels for LOG only (kernel_log_f16, kernel_log_f16_4), so SIN/COS stay F32-only
+            return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index ddc285042d284..7649001148159 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1487,6 +1487,26 @@ kernel void kernel_log_f32_4(
     dst[tpig] = log(src0[tpig]);
 }
 
+// natural log of a half buffer: each thread reads the element at its grid position,
+// evaluates log() in float and narrows the result back to half on store
+kernel void kernel_log_f16(
+        device const half * src0,
+        device       half * dst,
+        uint tpig [[thread_position_in_grid]]) {
+    const float x = (float)src0[tpig];
+    dst[tpig] = (half)log(x);
+}
+
+// half4 variant of kernel_log_f16: each thread converts one half4 to float4, applies log(), and stores it as half4
+kernel void kernel_log_f16_4(
+        device const half4 * src0,
+        device       half4 * dst,
+        uint tpig [[thread_position_in_grid]]) {
+    const half4  xh = src0[tpig];
+    const float4 xf = float4(xh);
+    dst[tpig] = half4(log(xf));
+}
+
 kernel void kernel_neg_f32(
         device const float * src0,
         device       float * dst,