Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
Expand Down
17 changes: 17 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,23 @@ kernel void kernel_log_f32_4(
dst[tpig] = log(src0[tpig]);
}

kernel void kernel_log_f16(
device const half * src0,
device half * dst,
uint tpig [[thread_position_in_grid]]) {
const float x = (float)src0[tpig];
dst[tpig] = (half)log(x);
}

kernel void kernel_log_f16_4(
device const half4 * src0,
device half4 * dst,
uint tpig [[thread_position_in_grid]]) {
const half4 xh = src0[tpig];
float4 xf = float4(xh);
dst[tpig] = half4(log(xf));
}

kernel void kernel_neg_f32(
device const float * src0,
device float * dst,
Expand Down