@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
241241 GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242242 GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243243 GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244246 GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245247 GGML_METAL_KERNEL_TYPE_PAD_F32,
246248 GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272274 GGML_METAL_KERNEL_TYPE_SIN,
273275 GGML_METAL_KERNEL_TYPE_COS,
274276 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
277+ GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
278+ GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
275279
276280 GGML_METAL_KERNEL_TYPE_COUNT
277281};
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass
685689 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true );
686690 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true );
687691 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true );
692+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true );
693+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true );
688694 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
689695 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
690696 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true );
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass
716722 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
717723 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
718724 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
725+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
726+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true );
719727 }
720728
721729 [metal_library release ];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844852 case GGML_OP_IM2COL:
845853 return op->src [0 ]->type == GGML_TYPE_F16;
846854 case GGML_OP_POOL_1D:
847- case GGML_OP_POOL_2D:
848855 return false ;
856+ case GGML_OP_POOL_2D:
849857 case GGML_OP_UPSCALE:
850858 case GGML_OP_PAD:
851859 case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
25452553 } break ;
25462554 case GGML_OP_IM2COL:
25472555 {
2556+ GGML_ASSERT (ggml_is_contiguous (src0));
2557+ GGML_ASSERT (ggml_is_contiguous (src1));
25482558 GGML_ASSERT (src0->type == GGML_TYPE_F16);
25492559 GGML_ASSERT (src1->type == GGML_TYPE_F32);
25502560 GGML_ASSERT ( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
25742584 const int32_t ofs0 = src1->nb [is_2D ? 3 : 2 ] / 4 ;
25752585 const int32_t ofs1 = src1->nb [is_2D ? 2 : 1 ] / 4 ;
25762586
2577- id <MTLComputePipelineState > pipeline = nil ;
2587+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ;
2588+
2589+ const bool is_gt_mttpt = ((size_t )(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup ;
25782590
25792591 switch (dst->type ) {
2580- case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ; break ;
2581- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline ; break ;
2592+ case GGML_TYPE_F32: {
2593+ pipeline = (is_gt_mttpt ?
2594+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
2595+ :
2596+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline );
2597+ } break ;
2598+ case GGML_TYPE_F16: {
2599+ pipeline = (is_gt_mttpt ?
2600+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
2601+ :
2602+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline );
2603+ } break ;
25822604 default : GGML_ABORT (" fatal error" );
25832605 };
25842606
25852607 [encoder setComputePipelineState: pipeline];
2586- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2587- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2588- [encoder setBytes: &ofs0 length: sizeof ( int32_t ) atIndex: 2 ];
2589- [encoder setBytes: &ofs1 length: sizeof ( int32_t ) atIndex: 3 ];
2590- [encoder setBytes: &IW length: sizeof ( int32_t ) atIndex: 4 ];
2591- [encoder setBytes: &IH length: sizeof ( int32_t ) atIndex: 5 ];
2592- [encoder setBytes: &CHW length: sizeof ( int32_t ) atIndex: 6 ];
2593- [encoder setBytes: &s0 length: sizeof ( int32_t ) atIndex: 7 ];
2594- [encoder setBytes: &s1 length: sizeof ( int32_t ) atIndex: 8 ];
2595- [encoder setBytes: &p0 length: sizeof ( int32_t ) atIndex: 9 ];
2596- [encoder setBytes: &p1 length: sizeof ( int32_t ) atIndex: 10 ];
2597- [encoder setBytes: &d0 length: sizeof ( int32_t ) atIndex: 11 ];
2598- [encoder setBytes: &d1 length: sizeof ( int32_t ) atIndex: 12 ];
2599-
2600- [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2608+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2609+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2610+ [encoder setBytes: &ofs0 length: sizeof (int32_t ) atIndex: 2 ];
2611+ [encoder setBytes: &ofs1 length: sizeof (int32_t ) atIndex: 3 ];
2612+ [encoder setBytes: &IW length: sizeof (int32_t ) atIndex: 4 ];
2613+ [encoder setBytes: &IH length: sizeof (int32_t ) atIndex: 5 ];
2614+ [encoder setBytes: &CHW length: sizeof (int32_t ) atIndex: 6 ];
2615+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 7 ];
2616+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 8 ];
2617+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 9 ];
2618+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 10 ];
2619+ [encoder setBytes: &d0 length: sizeof (int32_t ) atIndex: 11 ];
2620+ [encoder setBytes: &d1 length: sizeof (int32_t ) atIndex: 12 ];
2621+
2622+ if (is_gt_mttpt) {
2623+ [encoder setBytes: &N length: sizeof (int32_t ) atIndex: 13 ];
2624+ [encoder setBytes: &KH length: sizeof (int32_t ) atIndex: 14 ];
2625+ [encoder setBytes: &KW length: sizeof (int32_t ) atIndex: 15 ];
2626+
2627+ const uint64_t n_threads = MIN (pipeline.maxTotalThreadsPerThreadgroup , (uint64_t )N);
2628+
2629+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0 );
2630+
2631+ [encoder dispatchThreadgroups: MTLSizeMake (quotient * CHW, OH, OW) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
2632+ } else {
2633+ [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2634+ }
26012635 } break ;
26022636 case GGML_OP_UPSCALE:
26032637 {
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
30013035
30023036 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
30033037 } break ;
3038+ case GGML_OP_POOL_2D:
3039+ {
3040+ GGML_ASSERT (ggml_is_contiguous (src0));
3041+ GGML_ASSERT (src0t == GGML_TYPE_F32 && src0t == dstt);
3042+
3043+ const int32_t * opts = dst->op_params ;
3044+ enum ggml_op_pool op = opts[0 ];
3045+
3046+ id <MTLComputePipelineState > pipeline = nil ;
3047+ switch (src0t) {
3048+ case GGML_TYPE_F32: {
3049+ switch (op) {
3050+ case GGML_OP_POOL_AVG:
3051+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline ; break ;
3052+ case GGML_OP_POOL_MAX:
3053+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline ; break ;
3054+ default : GGML_ASSERT (false && " not implemented" );
3055+ }
3056+ } break ;
3057+ default : GGML_ASSERT (false && " not implemented" );
3058+ }
3059+
3060+ const int32_t k0 = opts[1 ];
3061+ const int32_t k1 = opts[2 ];
3062+ const int32_t s0 = opts[3 ];
3063+ const int32_t s1 = opts[4 ];
3064+ const int32_t p0 = opts[5 ];
3065+ const int32_t p1 = opts[6 ];
3066+
3067+ const int64_t IH = src0->ne [1 ];
3068+ const int64_t IW = src0->ne [0 ];
3069+
3070+ const int64_t N = dst->ne [3 ];
3071+ const int64_t OC = dst->ne [2 ];
3072+ const int64_t OH = dst->ne [1 ];
3073+ const int64_t OW = dst->ne [0 ];
3074+
3075+ const int64_t parallel_elements = N * OC * OH * OW;
3076+ const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3077+ const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3078+
3079+ [encoder setComputePipelineState: pipeline];
3080+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3081+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3082+ [encoder setBytes: &k0 length: sizeof (int32_t ) atIndex: 2 ];
3083+ [encoder setBytes: &k1 length: sizeof (int32_t ) atIndex: 3 ];
3084+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 4 ];
3085+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 5 ];
3086+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 6 ];
3087+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 7 ];
3088+ [encoder setBytes: &IH length: sizeof (int64_t ) atIndex: 8 ];
3089+ [encoder setBytes: &IW length: sizeof (int64_t ) atIndex: 9 ];
3090+ [encoder setBytes: &OH length: sizeof (int64_t ) atIndex: 10 ];
3091+ [encoder setBytes: &OW length: sizeof (int64_t ) atIndex: 11 ];
3092+ [encoder setBytes: ¶llel_elements length: sizeof (int64_t ) atIndex: 12 ];
3093+
3094+ [encoder dispatchThreadgroups: MTLSizeMake (n_tg, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
3095+ } break ;
30043096 default :
30053097 {
30063098 GGML_LOG_ERROR (" %s : error: node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments