Commit 389e7e4 (parent 063be20)

cont : improve, comments

ggml-ci
1 file changed: ggml/src/ggml-metal/ggml-metal.m (28 additions, 21 deletions)
@@ -1345,6 +1345,17 @@ @implementation GGMLMetalClass
     return res;
 }
 
+// return true if we should use the FA vector kernel for this op
+static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t ne00 = op->src[0]->ne[0]; // head size
+    const int64_t ne01 = op->src[0]->ne[1]; // batch size
+
+    // use vec kernel if the batch size is small and if the head size is supported
+    return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool has_mask,
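
For illustration, the new heuristic can be restated with plain integers. This is a minimal sketch, not part of the commit; the helper name and the example shapes are hypothetical:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    // same condition as ggml_metal_flash_attn_ext_use_vec above, taking the
    // two relevant dimensions directly instead of a ggml_tensor
    static bool use_vec_kernel(int64_t head_size, int64_t batch_size) {
        return (batch_size < 20) && (head_size % 32 == 0);
    }

    int main(void) {
        printf("%d\n", use_vec_kernel(128,  1)); // 1: small batch, head size is a multiple of 32
        printf("%d\n", use_vec_kernel(128, 32)); // 0: batch too large -> half8x8 kernel
        printf("%d\n", use_vec_kernel( 80,  3)); // 0: head size not a multiple of 32
        return 0;
    }
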
@@ -5067,9 +5078,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     GGML_ASSERT(ne01 < 65536);
 
-    // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-    if (ne01 >= 20 || (ne00 % 32 != 0)) {
+    if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
         // half8x8 kernel
         const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
         const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
@@ -5294,14 +5303,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
     GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-    const int32_t nrows = ne1*ne2*ne3;
-
-    // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
     const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
 
-    //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
-    //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
-
     [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
     [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -5312,6 +5316,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     // reduce the results from the workgroups
     {
+        const int32_t nrows = ne1*ne2*ne3;
+
         ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
             nrows,
         };
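
The "+ 2: the S and M values" bookkeeping refers to the per-workgroup softmax row sum (S) and running maximum (M). Below is a hedged host-side sketch of how such partials could be combined in this reduce step, assuming each workgroup writes ne20 unnormalized output values followed by S and M; the buffer layout and semantics are inferred from the size formula, not taken from the kernel source:

    #include <math.h>
    #include <stdint.h>

    // combine nwg partial results for one output row;
    // each partial is (ne20 floats, then S, then M)
    static void reduce_row(const float * part, int64_t nwg, int64_t ne20, float * out) {
        // find the global maximum across workgroups
        float M = -INFINITY;
        for (int64_t w = 0; w < nwg; ++w) {
            const float m = part[w*(ne20 + 2) + ne20 + 1];
            M = m > M ? m : M;
        }

        float S = 0.0f;
        for (int64_t d = 0; d < ne20; ++d) {
            out[d] = 0.0f;
        }

        for (int64_t w = 0; w < nwg; ++w) {
            const float s     = part[w*(ne20 + 2) + ne20 + 0];
            const float m     = part[w*(ne20 + 2) + ne20 + 1];
            const float scale = expf(m - M); // rescale partials to the global maximum

            S += s*scale;
            for (int64_t d = 0; d < ne20; ++d) {
                out[d] += part[w*(ne20 + 2) + d]*scale;
            }
        }

        for (int64_t d = 0; d < ne20; ++d) {
            out[d] /= S; // final softmax normalization
        }
    }
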
@@ -6150,30 +6156,31 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
 static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     size_t res = ggml_nbytes(tensor);
 
+    // some operations require additional memory for fleeting data:
     switch (tensor->op) {
         case GGML_OP_MUL_MAT_ID:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne21 = tensor->src[2]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
+                const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
 
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
                 res += ggml_type_size(GGML_TYPE_I32)*ne02;
-                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-                const int64_t nwg = 32;
+                if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
+                    const int64_t nwg = 32;
 
-                const int64_t ne01 = tensor->src[0]->ne[1];
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne03 = tensor->src[0]->ne[3];
-                const int64_t ne20 = tensor->src[2]->ne[0];
+                    const int64_t ne01 = tensor->src[0]->ne[1];
+                    const int64_t ne02 = tensor->src[0]->ne[2];
+                    const int64_t ne03 = tensor->src[0]->ne[3];
+                    const int64_t ne20 = tensor->src[2]->ne[0];
 
-                if (ne01 < 20) {
                     // temp buffer for writing the results from each workgroup
-                    // - ne20: the size of the head vector
+                    // - ne20: the size of the Value head
                     // - + 2: the S and M values for each intermediate result
+                    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
                     res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
                 }
             } break;
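
The FLASH_ATTN_EXT branch now reserves the temp buffer only when ggml_metal_flash_attn_ext_use_vec returns true, instead of re-deriving the ne01 < 20 condition locally. A standalone sketch of the size computation, with the helper name hypothetical and nwg = 32 plus the formula taken from the diff:

    #include <stddef.h>
    #include <stdint.h>

    // extra bytes for the FA vec kernel's temp buffer: ne01*ne02*ne03 output
    // rows, nwg partial results per row, each holding ne20 floats plus the
    // S and M scalars
    static size_t extra_size_flash_attn_ext_vec(int64_t ne01, int64_t ne02, int64_t ne03, int64_t ne20) {
        const int64_t nwg = 32; // number of workgroups, kept in sync with the encode path
        return sizeof(float)*(size_t)(ne01*ne02*ne03*nwg*(ne20 + 2));
    }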
