@@ -291,6 +291,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
291
291
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
292
292
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
293
293
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
294
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
295
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
296
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
297
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
294
298
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
295
299
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
296
300
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
575
579
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
576
580
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
577
581
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
582
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
578
583
GGML_METAL_KERNEL_TYPE_SET_I32,
579
584
GGML_METAL_KERNEL_TYPE_SET_F32,
580
585
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ @implementation GGMLMetalClass
1324
1329
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
1325
1330
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
1326
1331
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
1332
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
1333
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
1334
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
1335
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
1327
1336
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
1328
1337
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
1329
1338
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ @implementation GGMLMetalClass
1609
1618
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
1610
1619
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
1611
1620
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
1621
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
1612
1622
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true );
1613
1623
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true );
1614
1624
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
3385
3395
3386
3396
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
3387
3397
// to the matrix-vector kernel
3388
- const int ne11_mm_min = 4 ;
3398
+ const int ne11_mm_min = 8 ;
3389
3399
3390
3400
// first try to use small-batch mat-mv kernels
3391
3401
// these should be efficient for BS [2, ~8]
3392
- if (src1t == GGML_TYPE_F32 && (ne00%256 == 0 ) &&
3402
+ if (src1t == GGML_TYPE_F32 && (ne00%128 == 0 ) &&
3393
3403
(
3394
3404
(
3395
3405
(
3396
- src0t == GGML_TYPE_F16 || // TODO: helper function
3406
+ src0t == GGML_TYPE_F32 || // TODO: helper function
3407
+ src0t == GGML_TYPE_F16 ||
3397
3408
src0t == GGML_TYPE_Q4_0 ||
3398
3409
src0t == GGML_TYPE_Q4_1 ||
3399
3410
src0t == GGML_TYPE_Q5_0 ||
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
3421
3432
// values and there can be some tail effects when nsg is high. need to confirm this
3422
3433
//
3423
3434
const int nsg = 2 ; // num simdgroups per threadgroup
3424
- const int nxpsg = ne11 < 3 ? 16 : 8 ; // num threads along row per simdgroup
3435
+
3436
+ // num threads along row per simdgroup
3437
+ int nxpsg = 0 ;
3438
+ if (ne00 % 256 == 0 && ne11 < 3 ) {
3439
+ nxpsg = 16 ;
3440
+ } else if (ne00 % 128 == 0 ) {
3441
+ nxpsg = 8 ;
3442
+ } else {
3443
+ nxpsg = 4 ;
3444
+ }
3445
+
3425
3446
const int nypsg = 32 /nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
3426
3447
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
3427
3448
int r1ptg = 4 ; // num src1 rows per threadgroup
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
3444
3465
id <MTLComputePipelineState > pipeline = nil ;
3445
3466
3446
3467
switch (src0->type ) {
3468
+ case GGML_TYPE_F32:
3469
+ switch (r1ptg) {
3470
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline ; break ;
3471
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline ; break ;
3472
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline ; break ;
3473
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline ; break ;
3474
+ default : GGML_ABORT (" not implemented" );
3475
+ } break ;
3447
3476
case GGML_TYPE_F16:
3448
3477
switch (r1ptg) {
3449
3478
case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline ; break ;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
3598
3627
case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline ; break ;
3599
3628
case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline ; break ;
3600
3629
case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline ; break ;
3601
- case GGML_TYPE_MXFP4: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline ; break ;
3630
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline ; break ;
3602
3631
case GGML_TYPE_Q2_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline ; break ;
3603
3632
case GGML_TYPE_Q3_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline ; break ;
3604
3633
case GGML_TYPE_Q4_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline ; break ;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
5482
5511
/* .nb33 =*/ nb33,
5483
5512
/* .ne1 =*/ ne1,
5484
5513
/* .ne2 =*/ ne2,
5514
+ /* .ne3 =*/ ne3,
5485
5515
/* .scale =*/ scale,
5486
5516
/* .max_bias =*/ max_bias,
5487
5517
/* .m0 =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
5505
5535
} else {
5506
5536
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 5 ];
5507
5537
}
5508
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
5509
5538
5510
5539
if (!use_vec_kernel) {
5511
5540
// half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
5531
5560
5532
5561
while (true ) {
5533
5562
const size_t smem = FATTN_SMEM (nsgmax);
5534
- if (smem > device.maxThreadgroupMemoryLength ) {
5563
+ if (smem > device.maxThreadgroupMemoryLength / 2 ) {
5535
5564
break ;
5536
5565
}
5537
5566
nsgmax *= 2 ;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
5543
5572
5544
5573
const size_t smem = FATTN_SMEM (nsg);
5545
5574
5575
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
5576
+
5546
5577
// printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
5547
5578
GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5548
5579
[encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5549
- #undef FATTN_SMEM
5550
5580
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5581
+ #undef FATTN_SMEM
5551
5582
} else {
5552
5583
// half4x4 kernel
5553
5584
const int64_t nqptg = 1 ; // queries per threadgroup !! sync with kernel template arguments !!
5554
5585
const int64_t ncpsg = 32 ; // cache values per simdgroup !! sync with kernel template arguments !!
5586
+ const int64_t nkpsg = 1 *ncpsg; // TODO: make adjustable
5555
5587
5556
5588
GGML_ASSERT (nqptg <= 32 );
5557
5589
GGML_ASSERT (nqptg % 1 == 0 );
@@ -5561,37 +5593,100 @@ static int ggml_metal_encode_node(
5561
5593
// for each query, we load it as f16 in shared memory (ne00)
5562
5594
// and store the soft_max values and the mask
5563
5595
//
5564
- // ne00 *(nsg)
5596
+ // ne20 *(nsg)
5565
5597
// each simdgroup has a full f32 head vector in shared mem to accumulate results
5566
5598
//
5567
5599
#define FATTN_SMEM (nsg ) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128 ) + 4 *ncpsg*(nsg)) + 2 *ne20*(nsg))*(sizeof (float )/2 ), 16 ))
5600
+ // #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
5568
5601
5569
5602
int64_t nsgmax = 2 ;
5570
5603
while (true ) {
5571
5604
const size_t smem = FATTN_SMEM (nsgmax);
5572
- if (smem > device.maxThreadgroupMemoryLength ) {
5605
+ // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
5606
+ if (smem > device.maxThreadgroupMemoryLength /2 ) {
5573
5607
break ;
5574
5608
}
5575
5609
nsgmax *= 2 ;
5576
5610
}
5577
5611
nsgmax /= 2 ;
5578
5612
5579
5613
// simdgroups per threadgroup (a.k.a. warps)
5580
- const int64_t nsgt = MAX (2 , MIN (nsgmax, MIN (ne11/ncpsg , (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )));
5614
+ const int64_t nsgt = MAX (2 , MIN (nsgmax, MIN (( ne11 + nkpsg - 1 )/(nkpsg) , (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )));
5581
5615
5582
5616
int64_t nsg = 1 ;
5583
5617
while (nsg <= nsgt) {
5584
5618
nsg *= 2 ;
5585
5619
}
5586
5620
nsg /= 2 ;
5587
5621
5588
- const size_t smem = FATTN_SMEM (nsg);
5622
+ // workgroups
5623
+ // each workgroup handles nsg*nkpsg cache values
5624
+ uint16_t nwg = 1 ;
5625
+ if (4 *nsg*nkpsg >= ne11) {
5626
+ const size_t smem = FATTN_SMEM (nsg);
5589
5627
5590
- // printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
5591
- GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5592
- [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5628
+ // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
5629
+ GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5630
+
5631
+ // using 1 workgroup -> write the result directly into dst
5632
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
5633
+ [encoder setBytes: &nwg length: sizeof (uint16_t ) atIndex: 7 ];
5634
+
5635
+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5636
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5637
+ } else {
5638
+ nwg = 32 ;
5639
+ nsg = MIN (4 , nsg);
5640
+
5641
+ const size_t smem = FATTN_SMEM (nsg);
5642
+
5643
+ // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
5644
+ GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
5645
+
5646
+ // sanity checks
5647
+ GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
5648
+ GGML_ASSERT (ne1*ne2*ne3 <= (1u << 31 ));
5649
+
5650
+ const int32_t nrows = ne1*ne2*ne3;
5651
+
5652
+ // temp buffer for writing the results from each workgroup
5653
+ // - ne20: the size of the head vector
5654
+ // - + 2: the S and M values for each intermediate result
5655
+ const size_t s_tmp = ggml_type_size (GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2 ));
5656
+ id <MTLBuffer > h_tmp = ggml_metal_mem_pool_alloc (mem_pool, s_tmp);
5657
+ if (!h_tmp) {
5658
+ GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tmp);
5659
+ return 0 ;
5660
+ }
5661
+
5662
+ // printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
5663
+ // printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
5664
+
5665
+ [encoder setBuffer: h_tmp offset: 0 atIndex: 6 ];
5666
+ [encoder setBytes: &nwg length: sizeof (uint16_t ) atIndex: 7 ];
5667
+
5668
+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
5669
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5670
+
5671
+ // reduce the results from the workgroups
5672
+ {
5673
+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
5674
+ nrows,
5675
+ ne20,
5676
+ };
5677
+
5678
+ id <MTLComputePipelineState > pipeline0 = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline ;
5679
+
5680
+ [encoder setComputePipelineState: pipeline0];
5681
+ [encoder setBytes: &args0 length: sizeof (args0) atIndex: 0 ];
5682
+ [encoder setBuffer: h_tmp offset: 0 atIndex: 1 ];
5683
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
5684
+
5685
+ // printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
5686
+ [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (32 *32 , 1 , 1 )];
5687
+ }
5688
+ }
5593
5689
#undef FATTN_SMEM
5594
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
5595
5690
}
5596
5691
} break ;
5597
5692
case GGML_OP_DUP:
0 commit comments