@@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
173173 GGML_METAL_KERNEL_TYPE_SILU,
174174 GGML_METAL_KERNEL_TYPE_SILU_4,
175175 GGML_METAL_KERNEL_TYPE_ELU,
176+ GGML_METAL_KERNEL_TYPE_ABS,
177+ GGML_METAL_KERNEL_TYPE_SGN,
178+ GGML_METAL_KERNEL_TYPE_STEP,
179+ GGML_METAL_KERNEL_TYPE_HARDSWISH,
180+ GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
181+ GGML_METAL_KERNEL_TYPE_EXP,
176182 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
177183 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
178184 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass
11551161 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU, silu, true );
11561162 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true );
11571163 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ELU, elu, true );
1164+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ABS, abs, true );
1165+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SGN, sgn, true );
1166+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_STEP, step, true );
1167+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true );
1168+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true );
1169+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_EXP, exp, true );
11581170 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
11591171 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
11601172 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16881700 case GGML_UNARY_OP_SILU:
16891701 case GGML_UNARY_OP_ELU:
16901702 case GGML_UNARY_OP_NEG:
1703+ case GGML_UNARY_OP_ABS:
1704+ case GGML_UNARY_OP_SGN:
1705+ case GGML_UNARY_OP_STEP:
1706+ case GGML_UNARY_OP_HARDSWISH:
1707+ case GGML_UNARY_OP_HARDSIGMOID:
1708+ case GGML_UNARY_OP_EXP:
16911709 return ggml_is_contiguous (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
16921710 default :
16931711 return false ;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
24392457
24402458 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
24412459 } break ;
2460+ case GGML_UNARY_OP_ABS:
2461+ {
// Encode the ABS unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].
// NOTE(review): the kernel body is not visible here; presumably element-wise |x| -- confirm in the .metal source.
2462+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ABS].pipeline ;
2463+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2464+ [encoder setComputePipelineState: pipeline];
2465+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2466+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2467+
2468+ const int64_t n = ggml_nelements (dst);
2469+
// Launch n threadgroups of a single thread each (one per dst element).
2470+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2471+ } break ;
2472+ case GGML_UNARY_OP_SGN:
2473+ {
// Encode the SGN unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].
// NOTE(review): the kernel body is not visible here; presumably element-wise sign -- confirm in the .metal source.
2474+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SGN].pipeline ;
2475+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2476+ [encoder setComputePipelineState: pipeline];
2477+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2478+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2479+
2480+ const int64_t n = ggml_nelements (dst);
2481+
// Launch n threadgroups of a single thread each (one per dst element).
2482+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2483+ } break ;
2484+ case GGML_UNARY_OP_STEP:
2485+ {
// Encode the STEP unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].
// NOTE(review): the kernel body is not visible here; presumably an element-wise step function -- confirm in the .metal source.
2486+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_STEP].pipeline ;
2487+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2488+ [encoder setComputePipelineState: pipeline];
2489+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2490+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2491+
2492+ const int64_t n = ggml_nelements (dst);
2493+
// Launch n threadgroups of a single thread each (one per dst element).
2494+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2495+ } break ;
2496+ case GGML_UNARY_OP_HARDSWISH:
2497+ {
// Encode the HARDSWISH unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].
// NOTE(review): the kernel body is not visible here; the exact hardswish formula should be confirmed in the .metal source.
2498+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline ;
2499+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2500+ [encoder setComputePipelineState: pipeline];
2501+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2502+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2503+
2504+ const int64_t n = ggml_nelements (dst);
2505+
// Launch n threadgroups of a single thread each (one per dst element).
2506+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2507+ } break ;
2508+ case GGML_UNARY_OP_HARDSIGMOID:
2509+ {
// Encode the HARDSIGMOID unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].
// NOTE(review): the kernel body is not visible here; the exact hardsigmoid formula should be confirmed in the .metal source.
2510+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline ;
2511+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2512+ [encoder setComputePipelineState: pipeline];
2513+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2514+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2515+
2516+ const int64_t n = ggml_nelements (dst);
2517+
// Launch n threadgroups of a single thread each (one per dst element).
2518+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2519+ } break ;
2520+ case GGML_UNARY_OP_EXP:
2521+ {
// Encode the EXP unary op with the pipeline stored in
// ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].
// NOTE(review): the kernel body is not visible here; presumably element-wise exp(x) -- confirm in the .metal source.
2522+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_EXP].pipeline ;
2523+
// Buffer slot 0 is the source (src0), slot 1 is the destination (dst).
2524+ [encoder setComputePipelineState: pipeline];
2525+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2526+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2527+
2528+ const int64_t n = ggml_nelements (dst);
2529+
// Launch n threadgroups of a single thread each (one per dst element).
2530+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2531+ } break ;
24422532 default :
24432533 {
24442534 GGML_LOG_WARN (" %s : node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments