@@ -310,6 +310,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
310310 GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
311311 GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
312312 GGML_METAL_KERNEL_TYPE_PAD_F32,
313+ GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
313314 GGML_METAL_KERNEL_TYPE_ARANGE_F32,
314315 GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
315316 GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -877,6 +878,7 @@ @implementation GGMLMetalClass
877878 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true );
878879 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
879880 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
881+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true );
880882 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true );
881883 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true );
882884 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true );
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
10991101 case GGML_OP_POOL_2D:
11001102 case GGML_OP_UPSCALE:
11011103 case GGML_OP_PAD:
1104+ case GGML_OP_PAD_REFLECT_1D:
11021105 case GGML_OP_ARANGE:
11031106 case GGML_OP_TIMESTEP_EMBEDDING:
11041107 case GGML_OP_ARGSORT:
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
32583261
32593262 const int nth = MIN (1024 , ne0);
32603263
3264+ [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
3265+ } break ;
3266+ case GGML_OP_PAD_REFLECT_1D: // encode one compute dispatch for the pad_reflect_1d_f32 kernel
3267+ {
3268+ GGML_ASSERT (src0->type == GGML_TYPE_F32); // only an F32 pipeline is looked up below, so any other source type is rejected here
3269+
3270+ const int32_t p0 = ((const int32_t *)(dst->op_params ))[0 ]; // first int32 op parameter - presumably the padding before the row; TODO confirm against the op definition
3271+ const int32_t p1 = ((const int32_t *)(dst->op_params ))[1 ]; // second int32 op parameter - presumably the padding after the row; TODO confirm against the op definition
3272+
3273+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline ; // pipeline state stored in the context's kernel table for this op
3274+
3275+ [encoder setComputePipelineState: pipeline];
3276+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ]; // buffer slot 0: source data
3277+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ]; // buffer slot 1: destination data
3278+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ]; // slots 2-5: ne00..ne03 (presumably the src0 element counts per dimension - declared outside this view)
3279+ [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
3280+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
3281+ [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 5 ];
3282+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 6 ]; // slot 6: ne0 only; ne1..ne3 are not passed as arguments, they are used as the grid size in the dispatch below
3283+ [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 7 ]; // slots 7-10: nb00..nb03 (presumably the src0 byte strides - verify against the kernel signature)
3284+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 8 ];
3285+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 9 ];
3286+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 10 ];
3287+ [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 11 ]; // slots 11-14: nb0..nb3 (presumably the dst byte strides - verify against the kernel signature)
3288+ [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 12 ];
3289+ [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 13 ];
3290+ [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 14 ];
3291+ [encoder setBytes: &p0 length: sizeof (p0) atIndex: 15 ]; // slots 15-16: the two op parameters read above
3292+ [encoder setBytes: &p1 length: sizeof (p1) atIndex: 16 ];
3293+
3294+ const int nth = MIN (1024 , ne0); // threads per threadgroup: ne0, capped at 1024
3295+
32613296 [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )]; // threadgroup grid of ne1 x ne2 x ne3, each threadgroup nth x 1 x 1 threads
32623297 } break ;
32633298 case GGML_OP_ARANGE:
0 commit comments