@@ -499,6 +499,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1456,6 +1457,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,             cos,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,             neg,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,        sum_rows,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,            mean,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,          argmax,          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1655,6 +1657,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2402,11 +2405,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
 
+                nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2436,11 +2458,12 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
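
Note, not part of the diff: the host code above presumes a companion kernel_mean entry point in ggml-metal.metal that mirrors kernel_sum_rows — one threadgroup per row, per-simdgroup partial sums combined through the 32-float threadgroup buffer bound by setThreadgroupMemoryLength, and a final division by ne00. The following is only a minimal sketch of such a kernel under those assumptions, not the kernel actually shipped with this change; the ggml_metal_kargs_sum_rows fields beyond ne00 (nb01/nb02/nb03, nb1/nb2/nb3) are assumed here.

// Hypothetical sketch of kernel_mean (assumptions noted above).
// ggml_metal_kargs_sum_rows is taken from ggml's shared kernel-args header.
#include <metal_stdlib>
using namespace metal;

kernel void kernel_mean(
        constant ggml_metal_kargs_sum_rows & args      [[buffer(0)]],
        device   const char                * src0      [[buffer(1)]],
        device         char                * dst       [[buffer(2)]],
        threadgroup    float               * shmem_f32 [[threadgroup(0)]], // 32 floats from the host
        uint3   tgpig [[threadgroup_position_in_grid]],
        ushort  tpitg [[thread_position_in_threadgroup]],
        ushort  sgitg [[simdgroup_index_in_threadgroup]],
        ushort  tiisg [[thread_index_in_simdgroup]],
        ushort    ntg [[threads_per_threadgroup]]) {
    const int64_t i1 = tgpig.x; // row   index (ne01)
    const int64_t i2 = tgpig.y; // plane index (ne02)
    const int64_t i3 = tgpig.z; // batch index (ne03)

    device const float * src_row = (device const float *)(src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_row = (device       float *)(dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // each thread accumulates a strided slice of the row
    float sumf = 0.0f;
    for (int64_t i0 = tpitg; i0 < args.ne00; i0 += ntg) {
        sumf += src_row[i0];
    }

    // reduce within each simdgroup, then stage one partial sum per simdgroup
    sumf = simd_sum(sumf);
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // first simdgroup combines the partial sums and writes the result
    if (sgitg == 0) {
        const ushort nsg = (ntg + 31)/32; // number of simdgroups in the threadgroup
        float total = tiisg < nsg ? shmem_f32[tiisg] : 0.0f;
        total = simd_sum(total);
        if (tiisg == 0) {
            // kernel_sum_rows would store `total` directly; the mean divides by ne00
            dst_row[0] = total / (float) args.ne00;
        }
    }
}

The host-side sizing pairs with this layout: nth grows in powers of two from the SIMD width up to maxTotalThreadsPerThreadgroup and is then capped at ne00, so at the usual 1024-thread maximum a threadgroup holds at most 32 simdgroups, which is exactly what the 32-float threadgroup buffer accommodates.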