Commit a92bdd9

cont : add comments

ggml-ci

1 parent 6d0b222 commit a92bdd9

3 files changed: +28 -22 lines changed

ggml/src/ggml-metal/ggml-metal-impl.h (1 addition, 3 deletions)

@@ -259,10 +259,8 @@ typedef struct {
 } ggml_metal_kargs_flash_attn_ext;
 
 typedef struct {
+    int32_t nrows;
     int32_t ne20;
-    int32_t ne1;
-    int32_t ne2;
-    int32_t ne3;
 } ggml_metal_kargs_flash_attn_ext_reduce;
 
 typedef struct {

ggml/src/ggml-metal/ggml-metal.m (14 additions, 7 deletions)

@@ -5583,7 +5583,7 @@ static int ggml_metal_encode_node(
     // half4x4 kernel
     const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
     const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-    const int64_t nkpsg = 1*ncpsg;
+    const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
 
     GGML_ASSERT(nqptg <= 32);
     GGML_ASSERT(nqptg % 1 == 0);
@@ -5602,6 +5602,7 @@ static int ggml_metal_encode_node(
     int64_t nsgmax = 2;
     while (true) {
         const size_t smem = FATTN_SMEM(nsgmax);
+        // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
         if (smem > device.maxThreadgroupMemoryLength/2) {
             break;
         }
@@ -5642,8 +5643,16 @@ static int ggml_metal_encode_node(
     //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
     GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
 
-    // tokens per expert
-    const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*ne20 + ne01*ne02*ne03*nwg*2);
+    // sanity checks
+    GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+    GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+    const int32_t nrows = ne1*ne2*ne3;
+
+    // temp buffer for writing the results from each workgroup
+    // - ne20: the size of the head vector
+    // - + 2:  the S and M values for each intermediate result
+    const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
     id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
     if (!h_tmp) {
         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
@@ -5662,10 +5671,8 @@ static int ggml_metal_encode_node(
     // reduce the results from the workgroups
     {
         ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+            nrows,
             ne20,
-            ne1,
-            ne2,
-            ne3,
         };
 
         id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
@@ -5676,7 +5683,7 @@ static int ggml_metal_encode_node(
         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
         //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
-        [encoder dispatchThreadgroups:MTLSizeMake((uint64_t) ne1*ne2*ne3, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
     }
 }
 #undef FATTN_SMEM
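
Note on the new `s_tmp` formula: for each of the `nrows = ne1*ne2*ne3` output rows, every one of the `nwg` workgroups writes a partial head vector of `ne20` floats plus its two running-softmax scalars S and M, hence `nrows*nwg*(ne20 + 2)` floats in total. The standalone C sketch below only reproduces that arithmetic; the dimensions are hypothetical example values, not taken from this commit.

```c
// Sketch only: recomputes the temp-buffer size from the diff's formula
// using made-up dimensions for illustration.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne1  = 1, ne2 = 32, ne3 = 1; // hypothetical: 1 query, 32 heads, batch of 1
    const int64_t ne20 = 128;                  // hypothetical head (value) vector size
    const int64_t nwg  = 4;                    // workgroups cooperating on each row

    const int64_t nrows = ne1*ne2*ne3;

    // per row and per workgroup: ne20 partial output floats + 2 floats for S and M
    const size_t s_tmp = sizeof(float)*(size_t)(nrows*nwg*(ne20 + 2));

    printf("temp buffer: %zu bytes\n", s_tmp); // 4 * 32 * 4 * 130 = 66560
    return 0;
}
```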

ggml/src/ggml-metal/ggml-metal.metal (13 additions, 12 deletions)

@@ -5124,21 +5124,23 @@ kernel void kernel_flash_attn_ext_vec(
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        device float * dst1 = (device float *) dst + (((uint64_t)args.ne3*args.ne2*args.ne1)*DV)*nwg;
+        const int64_t nrows = args.ne3*args.ne2*args.ne1;
+        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+
         device float4 * dst4 = (device float4 *) dst;
+        device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
 
         const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
 
-        const uint64_t rid = (uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1;
-
         // interleave the workgroup data
         for (short i = tiisg; i < DV4; i += NW) {
-            dst4[(rid*DV4 + i)*nwg + iwg] = (float4) sr4[i]*S;
+            dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
         }
 
+        // store S and M
        if (nwg > 1 && tiisg == 0) {
-            dst1[(rid*2)*nwg + 2*iwg + 0] = ss[0];
-            dst1[(rid*2)*nwg + 2*iwg + 1] = ss[1];
+            dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+            dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
        }
    }
 }
@@ -5253,19 +5255,18 @@ kernel void kernel_flash_attn_ext_reduce(
     const short DV4 = DV/4;
 
     device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
-    device const float * ss = (device const float *) htmp + (uint64_t)args.ne3*args.ne2*args.ne1*DV*nwg;
+    device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
     device float4 * dst4 = (device float4 *) dst + rid*DV4;
 
-    float S = ss[(rid*2)*nwg + 2*iwg + 0];
-    float M = ss[(rid*2)*nwg + 2*iwg + 1];
-
-    const float m = simd_max(M);
+    float S = ss[rid*(2*nwg) + 2*iwg + 0];
+    float M = ss[rid*(2*nwg) + 2*iwg + 1];
 
+    const float m = simd_max(M);
     const float ms = exp(M - m);
 
     S = 1.0f/simd_sum(S*ms);
 
-    for (short i = sgitg; i < DV4; i += nwg) {
+    for (int i = sgitg; i < DV4; i += nwg) {
         const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
 
         if (iwg == 0) {
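
Taken together: the writer side (`kernel_flash_attn_ext_vec`) now interleaves the `nwg` partial outputs of each row so that element `i` of all workgroups lands at `dst4[rid*DV4*nwg + nwg*i + iwg]`, with the per-workgroup (S, M) pairs packed after all the results; `kernel_flash_attn_ext_reduce` then merges the partials with the usual streaming-softmax rescaling. Below is a plain-C sketch of that per-row merge under the stated assumptions; it replaces the `simd_max()`/`simd_sum()` lane reductions with explicit loops over workgroups, and all names are illustrative only.

```c
// Plain-C sketch of the cross-workgroup merge performed by the reduce kernel;
// not the Metal code itself, just the same math spelled out per row.
#include <math.h>

// part[w*DV + i] : partial output written by workgroup w (unscaled when nwg > 1)
// S[w], M[w]     : per-workgroup softmax sum and running max
static void reduce_row(float *out, const float *part,
                       const float *S, const float *M, int nwg, int DV) {
    float m = -INFINITY;
    for (int w = 0; w < nwg; ++w) {
        m = fmaxf(m, M[w]);                       // m = simd_max(M)
    }

    float sum = 0.0f;
    for (int w = 0; w < nwg; ++w) {
        sum += S[w]*expf(M[w] - m);               // S = 1.0f/simd_sum(S*ms)
    }
    const float scale = 1.0f/sum;

    for (int i = 0; i < DV; ++i) {
        float acc = 0.0f;
        for (int w = 0; w < nwg; ++w) {
            acc += part[w*DV + i]*expf(M[w] - m); // v = simd_sum(htmp4[...]*ms)
        }
        out[i] = acc*scale;                       // final rescale of the merged row
    }
}
```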
