@@ -272,6 +272,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272272 GGML_METAL_KERNEL_TYPE_SIN,
273273 GGML_METAL_KERNEL_TYPE_COS,
274274 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
275+ GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
276+ GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
275277
276278 GGML_METAL_KERNEL_TYPE_COUNT
277279};
@@ -716,6 +718,8 @@ @implementation GGMLMetalClass
716718 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
717719 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
718720 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
721+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true );
722+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true );
719723 }
720724
721725 [metal_library release ];
@@ -844,8 +848,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844848 case GGML_OP_IM2COL:
845849 return op->src [0 ]->type == GGML_TYPE_F16;
846850 case GGML_OP_POOL_1D:
847- case GGML_OP_POOL_2D:
848851 return false ;
852+ case GGML_OP_POOL_2D:
853+ return true ;
849854 case GGML_OP_UPSCALE:
850855 case GGML_OP_PAD:
851856 case GGML_OP_ARANGE:
@@ -3001,6 +3006,63 @@ static void ggml_metal_encode_node(
30013006
30023007 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
30033008 } break ;
3009+ case GGML_OP_POOL_2D:
3010+ {
3011+ GGML_ASSERT (src0t == GGML_TYPE_F32 && src0t == dstt);
3012+
3013+ const int32_t * opts = dst->op_params ;
3014+ enum ggml_op_pool op = opts[0 ];
3015+
3016+ id <MTLComputePipelineState > pipeline = nil ;
3017+ switch (src0t) {
3018+ case GGML_TYPE_F32: {
3019+ switch (op) {
3020+ case GGML_OP_POOL_AVG:
3021+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline ; break ;
3022+ case GGML_OP_POOL_MAX:
3023+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline ; break ;
3024+ default : GGML_ASSERT (false && " not implemented" );
3025+ }
3026+ } break ;
3027+ default : GGML_ASSERT (false && " not implemented" );
3028+ }
3029+
3030+ const int32_t k0 = opts[1 ];
3031+ const int32_t k1 = opts[2 ];
3032+ const int32_t s0 = opts[3 ];
3033+ const int32_t s1 = opts[4 ];
3034+ const int32_t p0 = opts[5 ];
3035+ const int32_t p1 = opts[6 ];
3036+
3037+ const int64_t IH = src0->ne [1 ];
3038+ const int64_t IW = src0->ne [0 ];
3039+
3040+ const int64_t N = dst->ne [3 ];
3041+ const int64_t OC = dst->ne [2 ];
3042+ const int64_t OH = dst->ne [1 ];
3043+ const int64_t OW = dst->ne [0 ];
3044+
3045+ const int64_t parallel_elements = N * OC * OH * OW;
3046+ const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3047+ const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3048+
3049+ [encoder setComputePipelineState: pipeline];
3050+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3051+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3052+ [encoder setBytes: &k0 length: sizeof (int32_t ) atIndex: 2 ];
3053+ [encoder setBytes: &k1 length: sizeof (int32_t ) atIndex: 3 ];
3054+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 4 ];
3055+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 5 ];
3056+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 6 ];
3057+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 7 ];
3058+ [encoder setBytes: &IH length: sizeof (int64_t ) atIndex: 8 ];
3059+ [encoder setBytes: &IW length: sizeof (int64_t ) atIndex: 9 ];
3060+ [encoder setBytes: &OH length: sizeof (int64_t ) atIndex: 10 ];
3061+ [encoder setBytes: &OW length: sizeof (int64_t ) atIndex: 11 ];
3062+ [encoder setBytes: ¶llel_elements length: sizeof (int64_t ) atIndex: 12 ];
3063+
3064+ [encoder dispatchThreadgroups: MTLSizeMake (n_tg, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
3065+ } break ;
30043066 default :
30053067 {
30063068 GGML_LOG_ERROR (" %s : error: node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments