Commit b3964c1
metal : optimize FA vec for large sequences and BS <= 8 (ggml-org#15566)
* metal : optimize FA vec for large heads and sequences
  ggml-ci
* metal : adjust small-batch mul mv kernels
  ggml-ci
* batched-bench : fix total speed computation
  ggml-ci
* cont : add comments
  ggml-ci
1 parent 79a5462 commit b3964c1
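The core of the FA vec change below is a split of long KV sequences across multiple workgroups (nwg), with a new flash_attn_ext_reduce kernel that folds the partial results back into dst. Each workgroup writes its partial head vector plus two scalars, S (softmax sum) and M (running max), into a temp buffer. The plain-C sketch below shows the kind of log-sum-exp merge such a reduce pass has to perform; it is illustrative only, the actual Metal kernel is not part of this excerpt, and the exact per-partial layout (vector followed by S then M) is an assumption here.

#include <math.h>

// Hypothetical reduce over nwg partial flash-attention results for one output row.
// part holds nwg blocks of (ne20 floats + S + M), mirroring the nrows*nwg*(ne20 + 2) temp buffer in the diff.
static void fa_reduce_row_sketch(const float * part, int nwg, int ne20, float * dst) {
    // global maximum over all partials
    float m_all = -INFINITY;
    for (int w = 0; w < nwg; ++w) {
        const float m = part[w*(ne20 + 2) + ne20 + 1];
        m_all = m > m_all ? m : m_all;
    }

    float s_all = 0.0f;
    for (int j = 0; j < ne20; ++j) {
        dst[j] = 0.0f;
    }

    for (int w = 0; w < nwg; ++w) {
        const float * o = part + w*(ne20 + 2);
        const float   s = o[ne20 + 0];     // partial softmax sum
        const float   m = o[ne20 + 1];     // partial running max
        const float   c = expf(m - m_all); // rescale this partial to the global max

        s_all += c*s;
        for (int j = 0; j < ne20; ++j) {
            dst[j] += c*o[j];
        }
    }

    for (int j = 0; j < ne20; ++j) {
        dst[j] /= s_all; // final normalization
    }
}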

File tree

4 files changed: +183 -25 lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
@@ -249,6 +249,7 @@ typedef struct {
     uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
+    int32_t  ne3;
     float    scale;
     float    max_bias;
     float    m0;
@@ -257,6 +258,11 @@ typedef struct {
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
+typedef struct {
+    int32_t nrows;
+    int32_t ne20;
+} ggml_metal_kargs_flash_attn_ext_reduce;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne02;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 111 additions & 16 deletions
@@ -291,6 +291,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
 
     // find the break-even point where the matrix-matrix kernel becomes more efficient compared
     // to the matrix-vector kernel
-    const int ne11_mm_min = 4;
+    const int ne11_mm_min = 8;
 
     // first try to use small-batch mat-mv kernels
     // these should be efficient for BS [2, ~8]
-    if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+    if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
         (
          (
           (
-           src0t == GGML_TYPE_F16 || // TODO: helper function
+           src0t == GGML_TYPE_F32 || // TODO: helper function
+           src0t == GGML_TYPE_F16 ||
            src0t == GGML_TYPE_Q4_0 ||
            src0t == GGML_TYPE_Q4_1 ||
            src0t == GGML_TYPE_Q5_0 ||
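In short, the hunk above raises the mat-mat break-even point from 4 to 8 src1 rows, relaxes the row-size requirement for the small-batch mat-vec path from ne00 % 256 == 0 to ne00 % 128 == 0, and adds F32 src0 to the supported types. Below is a condensed sketch of the resulting selection logic, with the full type list and the remaining conditions omitted; the function name and the exact ordering are simplifications, not the real code.

// Simplified sketch of the mul_mat kernel choice after this change (assumed shape, not the actual function).
// ne00 = src0 row size, ne11 = number of src1 rows (batch size).
static const char * choose_mul_mat_kernel_sketch(int ne00, int ne11, int src0_type_supported) {
    const int ne11_mm_min = 8; // break-even vs. the matrix-matrix kernel (was 4 before this commit)

    if (src0_type_supported && ne00 % 128 == 0 && ne11 >= 2 && ne11 <= 8) {
        return "mul_mv_ext"; // small-batch mat-vec kernels, efficient for BS in [2, ~8]
    }
    if (ne11 < ne11_mm_min) {
        return "mul_mv";     // regular mat-vec path
    }
    return "mul_mm";         // matrix-matrix kernel
}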
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
     // values and there can be some tail effects when nsg is high. need to confirm this
     //
     const int nsg = 2; // num simdgroups per threadgroup
-    const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+    // num threads along row per simdgroup
+    int nxpsg = 0;
+    if (ne00 % 256 == 0 && ne11 < 3) {
+        nxpsg = 16;
+    } else if (ne00 % 128 == 0) {
+        nxpsg = 8;
+    } else {
+        nxpsg = 4;
+    }
+
     const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
     const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
     int r1ptg = 4; // num src1 rows per threadgroup
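With nsg = 2 simdgroups per threadgroup (as in the hunk above), the rest of the layout follows directly from nxpsg. A few worked examples of the arithmetic; the shape values are illustrative only.

// nxpsg selection and derived layout, assuming nsg = 2:
//
//   ne00 = 4096, ne11 = 2 : ne00 % 256 == 0 and ne11 < 3 -> nxpsg = 16, nypsg = 32/16 = 2, r0ptg = 2*2 = 4
//   ne00 = 4096, ne11 = 4 : ne11 is not < 3              -> nxpsg =  8, nypsg = 32/8  = 4, r0ptg = 4*2 = 8
//   ne00 = 1152, ne11 = 2 : 1152 % 256 != 0, % 128 == 0  -> nxpsg =  8, nypsg = 32/8  = 4, r0ptg = 4*2 = 8
//
// The nxpsg = 4 fallback covers any remaining row sizes that still reach this code path.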
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
     id<MTLComputePipelineState> pipeline = nil;
 
     switch (src0->type) {
+        case GGML_TYPE_F32:
+            switch (r1ptg) {
+                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+                default: GGML_ABORT("not implemented");
+            } break;
         case GGML_TYPE_F16:
             switch (r1ptg) {
                 case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
     case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
     case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
     case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
-    case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
+    case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
     case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
     case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
     case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
     /*.nb33     =*/ nb33,
     /*.ne1      =*/ ne1,
     /*.ne2      =*/ ne2,
+    /*.ne3      =*/ ne3,
     /*.scale    =*/ scale,
     /*.max_bias =*/ max_bias,
     /*.m0       =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
     } else {
         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
     }
-    [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
 
     if (!use_vec_kernel) {
         // half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
 
         while (true) {
             const size_t smem = FATTN_SMEM(nsgmax);
-            if (smem > device.maxThreadgroupMemoryLength) {
+            if (smem > device.maxThreadgroupMemoryLength/2) {
                 break;
             }
             nsgmax *= 2;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
 
         const size_t smem = FATTN_SMEM(nsg);
 
+        [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
         //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
         GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
         [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
     } else {
         // half4x4 kernel
         const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
         const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg % 1 == 0);
@@ -5561,37 +5593,100 @@ static int ggml_metal_encode_node(
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
         //
-        // ne00*(nsg)
+        // ne20*(nsg)
         // each simdgroup has a full f32 head vector in shared mem to accumulate results
         //
 #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
 
         int64_t nsgmax = 2;
         while (true) {
             const size_t smem = FATTN_SMEM(nsgmax);
-            if (smem > device.maxThreadgroupMemoryLength) {
+            // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
+            if (smem > device.maxThreadgroupMemoryLength/2) {
                 break;
             }
             nsgmax *= 2;
         }
         nsgmax /= 2;
 
         // simdgroups per threadgroup (a.k.a. warps)
-        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
         int64_t nsg = 1;
         while (nsg <= nsgt) {
             nsg *= 2;
         }
         nsg /= 2;
 
-        const size_t smem = FATTN_SMEM(nsg);
+        // workgroups
+        // each workgroup handles nsg*nkpsg cache values
+        uint16_t nwg = 1;
+        if (4*nsg*nkpsg >= ne11) {
+            const size_t smem = FATTN_SMEM(nsg);
 
-        //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
-        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+            //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+            GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+            // using 1 workgroup -> write the result directly into dst
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+            [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+            [encoder setThreadgroupMemoryLength:smem atIndex:0];
+            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+        } else {
+            nwg = 32;
+            nsg = MIN(4, nsg);
+
+            const size_t smem = FATTN_SMEM(nsg);
+
+            //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+            GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+            // sanity checks
+            GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+            GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+            const int32_t nrows = ne1*ne2*ne3;
+
+            // temp buffer for writing the results from each workgroup
+            // - ne20: the size of the head vector
+            // - + 2:  the S and M values for each intermediate result
+            const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
+            id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+            if (!h_tmp) {
+                GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+                return 0;
+            }
+
+            //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+            //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+            [encoder setBuffer:h_tmp offset:0 atIndex:6];
+            [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+            [encoder setThreadgroupMemoryLength:smem atIndex:0];
+            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+            // reduce the results from the workgroups
+            {
+                ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+                    nrows,
+                    ne20,
+                };
+
+                id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+                [encoder setComputePipelineState:pipeline0];
+                [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
+                [encoder setBuffer:h_tmp offset:0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+                //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+            }
+        }
 #undef FATTN_SMEM
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     }
 } break;
 case GGML_OP_DUP:
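Putting the new vec-kernel dispatch together: when 4*nsg*nkpsg >= ne11, a single workgroup per head (nwg = 1) writes straight into dst as before; otherwise nwg = 32 workgroups split the KV sequence, each writing ne20 + 2 floats (partial head vector plus S and M) into a pooled temp buffer, and flash_attn_ext_reduce is dispatched with one threadgroup per output row to combine them. Below is a small sketch of the temp-buffer sizing used above; the shape values are made-up examples, not numbers from the commit.

#include <stddef.h>
#include <stdio.h>

// Mirrors: s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2))
static size_t fa_vec_tmp_size(int nrows, int nwg, int ne20) {
    return sizeof(float) * (size_t) nrows * nwg * (ne20 + 2);
}

int main(void) {
    // illustrative only: 32 output rows, 32 workgroups, head size 128
    printf("%zu bytes\n", fa_vec_tmp_size(32, 32, 128)); // 532480 bytes, about 0.51 MiB
    return 0;
}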
