@@ -291,6 +291,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
291291 GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
292292 GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
293293 GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
294+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
295+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
296+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
297+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
294298 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
295299 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
296300 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
575579 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
576580 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
577581 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
582+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
578583 GGML_METAL_KERNEL_TYPE_SET_I32,
579584 GGML_METAL_KERNEL_TYPE_SET_F32,
580585 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ @implementation GGMLMetalClass
13241329 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
13251330 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
13261331 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
1332+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
1333+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
1334+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
1335+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
13271336 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
13281337 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
13291338 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ @implementation GGMLMetalClass
16091618 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
16101619 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
16111620 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
1621+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
16121622 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true );
16131623 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true );
16141624 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
33853395
33863396 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
33873397 // to the matrix-vector kernel
3388- const int ne11_mm_min = 4 ;
3398+ const int ne11_mm_min = 8 ;
33893399
33903400 // first try to use small-batch mat-mv kernels
33913401 // these should be efficient for BS [2, ~8]
3392- if (src1t == GGML_TYPE_F32 && (ne00%256 == 0 ) &&
3402+ if (src1t == GGML_TYPE_F32 && (ne00%128 == 0 ) &&
33933403 (
33943404 (
33953405 (
3396- src0t == GGML_TYPE_F16 || // TODO: helper function
3406+ src0t == GGML_TYPE_F32 || // TODO: helper function
3407+ src0t == GGML_TYPE_F16 ||
33973408 src0t == GGML_TYPE_Q4_0 ||
33983409 src0t == GGML_TYPE_Q4_1 ||
33993410 src0t == GGML_TYPE_Q5_0 ||
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
34213432 // values and there can be some tail effects when nsg is high. need to confirm this
34223433 //
34233434 const int nsg = 2 ; // num simdgroups per threadgroup
3424- const int nxpsg = ne11 < 3 ? 16 : 8 ; // num threads along row per simdgroup
3435+
3436+ // num threads along row per simdgroup
3437+ int nxpsg = 0 ;
3438+ if (ne00 % 256 == 0 && ne11 < 3 ) {
3439+ nxpsg = 16 ;
3440+ } else if (ne00 % 128 == 0 ) {
3441+ nxpsg = 8 ;
3442+ } else {
3443+ nxpsg = 4 ;
3444+ }
3445+
34253446 const int nypsg = 32 /nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
34263447 const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
34273448 int r1ptg = 4 ; // num src1 rows per threadgroup
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
34443465 id <MTLComputePipelineState > pipeline = nil ;
34453466
34463467 switch (src0->type ) {
3468+ case GGML_TYPE_F32:
3469+ switch (r1ptg) {
3470+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline ; break ;
3471+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline ; break ;
3472+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline ; break ;
3473+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline ; break ;
3474+ default : GGML_ABORT (" not implemented" );
3475+ } break ;
34473476 case GGML_TYPE_F16:
34483477 switch (r1ptg) {
34493478 case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline ; break ;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
35983627 case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline ; break ;
35993628 case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline ; break ;
36003629 case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline ; break ;
3601- case GGML_TYPE_MXFP4: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline ; break ;
3630+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline ; break ;
36023631 case GGML_TYPE_Q2_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline ; break ;
36033632 case GGML_TYPE_Q3_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline ; break ;
36043633 case GGML_TYPE_Q4_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline ; break ;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
54825511 /* .nb33 =*/ nb33,
54835512 /* .ne1 =*/ ne1,
54845513 /* .ne2 =*/ ne2,
5514+ /* .ne3 =*/ ne3,
54855515 /* .scale =*/ scale,
54865516 /* .max_bias =*/ max_bias,
54875517 /* .m0 =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
55055535 } else {
55065536 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 5 ];
55075537 }
5508- [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
55095538
55105539 if (!use_vec_kernel) {
55115540 // half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
55315560
55325561 while (true ) {
55335562 const size_t smem = FATTN_SMEM (nsgmax);
5534- if (smem > device.maxThreadgroupMemoryLength ) {
5563+ if (smem > device.maxThreadgroupMemoryLength / 2 ) {
55355564 break ;
55365565 }
55375566 nsgmax *= 2 ;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
55435572
55445573 const size_t smem = FATTN_SMEM (nsg);
55455574
5575+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
5576+
55465577 // printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
55475578 GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
55485579 [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5549- #undef FATTN_SMEM
55505580 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5581+ #undef FATTN_SMEM
55515582 } else {
55525583 // half4x4 kernel
55535584 const int64_t nqptg = 1 ; // queries per threadgroup !! sync with kernel template arguments !!
55545585 const int64_t ncpsg = 32 ; // cache values per simdgroup !! sync with kernel template arguments !!
5586+ const int64_t nkpsg = 1 *ncpsg; // TODO: make adjustable
55555587
55565588 GGML_ASSERT (nqptg <= 32 );
55575589 GGML_ASSERT (nqptg % 1 == 0 );
@@ -5561,37 +5593,100 @@ static int ggml_metal_encode_node(
55615593 // for each query, we load it as f16 in shared memory (ne00)
55625594 // and store the soft_max values and the mask
55635595 //
5564- // ne00 *(nsg)
5596+ // ne20 *(nsg)
55655597 // each simdgroup has a full f32 head vector in shared mem to accumulate results
55665598 //
55675599#define FATTN_SMEM (nsg ) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128 ) + 4 *ncpsg*(nsg)) + 2 *ne20*(nsg))*(sizeof (float )/2 ), 16 ))
5600+ // #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
55685601
55695602 int64_t nsgmax = 2 ;
55705603 while (true ) {
55715604 const size_t smem = FATTN_SMEM (nsgmax);
5572- if (smem > device.maxThreadgroupMemoryLength ) {
5605+ // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
5606+ if (smem > device.maxThreadgroupMemoryLength /2 ) {
55735607 break ;
55745608 }
55755609 nsgmax *= 2 ;
55765610 }
55775611 nsgmax /= 2 ;
55785612
55795613 // simdgroups per threadgroup (a.k.a. warps)
5580- const int64_t nsgt = MAX (2 , MIN (nsgmax, MIN (ne11/ncpsg , (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )));
5614+ const int64_t nsgt = MAX (2 , MIN (nsgmax, MIN (( ne11 + nkpsg - 1 )/(nkpsg) , (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )));
55815615
55825616 int64_t nsg = 1 ;
55835617 while (nsg <= nsgt) {
55845618 nsg *= 2 ;
55855619 }
55865620 nsg /= 2 ;
55875621
5588- const size_t smem = FATTN_SMEM (nsg);
5622+ // workgroups
5623+ // each workgroup handles nsg*nkpsg cache values
5624+ uint16_t nwg = 1 ;
5625+ if (4 *nsg*nkpsg >= ne11) {
5626+ const size_t smem = FATTN_SMEM (nsg);
55895627
5590- // printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
5591- GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5592- [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5628+ // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
5629+ GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5630+
5631+ // using 1 workgroup -> write the result directly into dst
5632+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
5633+ [encoder setBytes: &nwg length: sizeof (uint16_t ) atIndex: 7 ];
5634+
5635+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5636+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5637+ } else {
5638+ nwg = 32 ;
5639+ nsg = MIN (4 , nsg);
5640+
5641+ const size_t smem = FATTN_SMEM (nsg);
5642+
5643+ // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
5644+ GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5645+
5646+ // sanity checks
5647+ GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
5648+ GGML_ASSERT (ne1*ne2*ne3 <= (1u << 31 ));
5649+
5650+ const int32_t nrows = ne1*ne2*ne3;
5651+
5652+ // temp buffer for writing the results from each workgroup
5653+ // - ne20: the size of the head vector
5654+ // - + 2: the S and M values for each intermediate result
5655+ const size_t s_tmp = ggml_type_size (GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2 ));
5656+ id <MTLBuffer > h_tmp = ggml_metal_mem_pool_alloc (mem_pool, s_tmp);
5657+ if (!h_tmp) {
5658+ GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tmp);
5659+ return 0 ;
5660+ }
5661+
5662+ // printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
5663+ // printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
5664+
5665+ [encoder setBuffer: h_tmp offset: 0 atIndex: 6 ];
5666+ [encoder setBytes: &nwg length: sizeof (uint16_t ) atIndex: 7 ];
5667+
5668+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5669+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5670+
5671+ // reduce the results from the workgroups
5672+ {
5673+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
5674+ nrows,
5675+ ne20,
5676+ };
5677+
5678+ id <MTLComputePipelineState > pipeline0 = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline ;
5679+
5680+ [encoder setComputePipelineState: pipeline0];
5681+ [encoder setBytes: &args0 length: sizeof (args0) atIndex: 0 ];
5682+ [encoder setBuffer: h_tmp offset: 0 atIndex: 1 ];
5683+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
5684+
5685+ // printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
5686+ [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 *32 , 1 , 1 )];
5687+ }
5688+ }
55935689#undef FATTN_SMEM
5594- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
55955690 }
55965691 } break ;
55975692 case GGML_OP_DUP:
0 commit comments