@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
241241    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242242    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243243    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244+     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245+     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244246    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245247    GGML_METAL_KERNEL_TYPE_PAD_F32,
246248    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272274    GGML_METAL_KERNEL_TYPE_SIN,
273275    GGML_METAL_KERNEL_TYPE_COS,
274276    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
277+     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
278+     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
275279
276280    GGML_METAL_KERNEL_TYPE_COUNT
277281};
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass
685689        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true );
686690        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true );
687691        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true );
692+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true );
693+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true );
688694        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true );
689695        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true );
690696        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true );
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass
716722        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true );
717723        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true );
718724        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true );
725+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true );
726+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true );
719727    }
720728
721729    [metal_library release ];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844852        case  GGML_OP_IM2COL:
845853            return  op->src [0 ]->type  == GGML_TYPE_F16;
846854        case  GGML_OP_POOL_1D:
847-         case  GGML_OP_POOL_2D:
848855            return  false ;
856+         case  GGML_OP_POOL_2D:
849857        case  GGML_OP_UPSCALE:
850858        case  GGML_OP_PAD:
851859        case  GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
25452553            } break ;
25462554        case  GGML_OP_IM2COL:
25472555            {
2556+                 GGML_ASSERT (ggml_is_contiguous (src0));
2557+                 GGML_ASSERT (ggml_is_contiguous (src1));
25482558                GGML_ASSERT (src0->type  == GGML_TYPE_F16);
25492559                GGML_ASSERT (src1->type  == GGML_TYPE_F32);
25502560                GGML_ASSERT ( dst->type  == GGML_TYPE_F16 || dst->type  == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
25742584                const  int32_t  ofs0 = src1->nb [is_2D ? 3  : 2 ] / 4 ;
25752585                const  int32_t  ofs1 = src1->nb [is_2D ? 2  : 1 ] / 4 ;
25762586
2577-                 id <MTLComputePipelineState > pipeline = nil ;
2587+                 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ;
2588+ 
2589+                 const  bool  is_gt_mttpt = ((size_t )(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup ;
25782590
25792591                switch  (dst->type ) {
2580-                     case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ; break ;
2581-                     case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline ; break ;
2592+                     case  GGML_TYPE_F32: {
2593+                         pipeline = (is_gt_mttpt ?
2594+                                     ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline 
2595+                                     :
2596+                                     ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline );
2597+                     } break ;
2598+                     case  GGML_TYPE_F16: {
2599+                         pipeline = (is_gt_mttpt ?
2600+                                     ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline 
2601+                                     :
2602+                                     ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline );
2603+                     } break ;
25822604                    default : GGML_ABORT (" fatal error"  );
25832605                };
25842606
25852607                [encoder setComputePipelineState: pipeline];
2586-                 [encoder setBuffer: id_src1 offset: offs_src1        atIndex: 0 ];
2587-                 [encoder setBuffer: id_dst  offset: offs_dst         atIndex: 1 ];
2588-                 [encoder setBytes: &ofs0    length: sizeof ( int32_t ) atIndex: 2 ];
2589-                 [encoder setBytes: &ofs1    length: sizeof ( int32_t ) atIndex: 3 ];
2590-                 [encoder setBytes: &IW      length: sizeof ( int32_t ) atIndex: 4 ];
2591-                 [encoder setBytes: &IH      length: sizeof ( int32_t ) atIndex: 5 ];
2592-                 [encoder setBytes: &CHW     length: sizeof ( int32_t ) atIndex: 6 ];
2593-                 [encoder setBytes: &s0      length: sizeof ( int32_t ) atIndex: 7 ];
2594-                 [encoder setBytes: &s1      length: sizeof ( int32_t ) atIndex: 8 ];
2595-                 [encoder setBytes: &p0      length: sizeof ( int32_t ) atIndex: 9 ];
2596-                 [encoder setBytes: &p1      length: sizeof ( int32_t ) atIndex: 10 ];
2597-                 [encoder setBytes: &d0      length: sizeof ( int32_t ) atIndex: 11 ];
2598-                 [encoder setBytes: &d1      length: sizeof ( int32_t ) atIndex: 12 ];
2599- 
2600-                 [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2608+                 [encoder setBuffer: id_src1 offset: offs_src1       atIndex: 0 ];
2609+                 [encoder setBuffer: id_dst  offset: offs_dst        atIndex: 1 ];
2610+                 [encoder setBytes: &ofs0    length: sizeof (int32_t ) atIndex: 2 ];
2611+                 [encoder setBytes: &ofs1    length: sizeof (int32_t ) atIndex: 3 ];
2612+                 [encoder setBytes: &IW      length: sizeof (int32_t ) atIndex: 4 ];
2613+                 [encoder setBytes: &IH      length: sizeof (int32_t ) atIndex: 5 ];
2614+                 [encoder setBytes: &CHW     length: sizeof (int32_t ) atIndex: 6 ];
2615+                 [encoder setBytes: &s0      length: sizeof (int32_t ) atIndex: 7 ];
2616+                 [encoder setBytes: &s1      length: sizeof (int32_t ) atIndex: 8 ];
2617+                 [encoder setBytes: &p0      length: sizeof (int32_t ) atIndex: 9 ];
2618+                 [encoder setBytes: &p1      length: sizeof (int32_t ) atIndex: 10 ];
2619+                 [encoder setBytes: &d0      length: sizeof (int32_t ) atIndex: 11 ];
2620+                 [encoder setBytes: &d1      length: sizeof (int32_t ) atIndex: 12 ];
2621+ 
2622+                 if  (is_gt_mttpt) {
2623+                     [encoder setBytes: &N   length: sizeof (int32_t ) atIndex: 13 ];
2624+                     [encoder setBytes: &KH  length: sizeof (int32_t ) atIndex: 14 ];
2625+                     [encoder setBytes: &KW  length: sizeof (int32_t ) atIndex: 15 ];
2626+ 
2627+                     const  uint64_t  n_threads = MIN (pipeline.maxTotalThreadsPerThreadgroup , (uint64_t )N);
2628+ 
2629+                     const  int64_t   quotient  = N / n_threads + (N % n_threads > 0  ? 1  : 0 );
2630+ 
2631+                     [encoder dispatchThreadgroups: MTLSizeMake (quotient * CHW, OH, OW) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
2632+                 } else  {
2633+                     [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2634+                 }
26012635            } break ;
26022636        case  GGML_OP_UPSCALE:
26032637            {
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
30013035
30023036                [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
30033037            } break ;
3038+         case  GGML_OP_POOL_2D:
3039+             {
3040+                 GGML_ASSERT (ggml_is_contiguous (src0));
3041+                 GGML_ASSERT (src0t == GGML_TYPE_F32 && src0t == dstt);
3042+ 
3043+                 const  int32_t  * opts = dst->op_params ;
3044+                 enum  ggml_op_pool op = opts[0 ];
3045+ 
3046+                 id <MTLComputePipelineState > pipeline = nil ;
3047+                 switch  (src0t) {
3048+                     case  GGML_TYPE_F32: {
3049+                         switch (op) {
3050+                             case  GGML_OP_POOL_AVG:
3051+                                 pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline ; break ;
3052+                             case  GGML_OP_POOL_MAX:
3053+                                 pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline ; break ;
3054+                             default : GGML_ASSERT (false  && " not implemented"  );
3055+                         }
3056+                     } break ;
3057+                     default : GGML_ASSERT (false  && " not implemented"  );
3058+                 }
3059+ 
3060+                 const  int32_t  k0 = opts[1 ];
3061+                 const  int32_t  k1 = opts[2 ];
3062+                 const  int32_t  s0 = opts[3 ];
3063+                 const  int32_t  s1 = opts[4 ];
3064+                 const  int32_t  p0 = opts[5 ];
3065+                 const  int32_t  p1 = opts[6 ];
3066+ 
3067+                 const  int64_t  IH = src0->ne [1 ];
3068+                 const  int64_t  IW = src0->ne [0 ];
3069+ 
3070+                 const  int64_t  N  = dst->ne [3 ];
3071+                 const  int64_t  OC = dst->ne [2 ];
3072+                 const  int64_t  OH = dst->ne [1 ];
3073+                 const  int64_t  OW = dst->ne [0 ];
3074+ 
3075+                 const  int64_t  parallel_elements = N * OC * OH * OW;
3076+                 const  int64_t  n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3077+                 const  int64_t  n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3078+ 
3079+                 [encoder setComputePipelineState: pipeline];
3080+                 [encoder setBuffer: id_src0 offset: offs_src0       atIndex: 0 ];
3081+                 [encoder setBuffer: id_dst  offset: offs_dst        atIndex: 1 ];
3082+                 [encoder setBytes: &k0      length: sizeof (int32_t ) atIndex: 2 ];
3083+                 [encoder setBytes: &k1      length: sizeof (int32_t ) atIndex: 3 ];
3084+                 [encoder setBytes: &s0      length: sizeof (int32_t ) atIndex: 4 ];
3085+                 [encoder setBytes: &s1      length: sizeof (int32_t ) atIndex: 5 ];
3086+                 [encoder setBytes: &p0      length: sizeof (int32_t ) atIndex: 6 ];
3087+                 [encoder setBytes: &p1      length: sizeof (int32_t ) atIndex: 7 ];
3088+                 [encoder setBytes: &IH      length: sizeof (int64_t ) atIndex: 8 ];
3089+                 [encoder setBytes: &IW      length: sizeof (int64_t ) atIndex: 9 ];
3090+                 [encoder setBytes: &OH      length: sizeof (int64_t ) atIndex: 10 ];
3091+                 [encoder setBytes: &OW      length: sizeof (int64_t ) atIndex: 11 ];
3092+                 [encoder setBytes: ¶llel_elements length: sizeof (int64_t ) atIndex: 12 ];
3093+ 
3094+                 [encoder dispatchThreadgroups: MTLSizeMake (n_tg, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
3095+             } break ;
30043096       default :
30053097            {
30063098                GGML_LOG_ERROR (" %s : error: node %3d , op = %8s  not implemented\n "  , __func__, idx, ggml_op_name (dst->op ));
0 commit comments