
Commit cfe86b2

cont : improve, comments
ggml-ci
1 parent 400e037 commit cfe86b2

File tree: 1 file changed (+28, -21 lines)

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 28 additions & 21 deletions
@@ -1345,6 +1345,17 @@ @implementation GGMLMetalClass
     return res;
 }
 
+// return true if we should use the FA vector kernel for this op
+static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t ne00 = op->ne[0]; // head size
+    const int64_t ne01 = op->ne[1]; // batch size
+
+    // use vec kernel if the batch size is small and if the head size is supported
+    return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool has_mask,
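
For context: the new helper centralizes a heuristic that was previously duplicated between the kernel-dispatch path and the allocation-size path (both call sites appear in the hunks below), so the two can no longer drift apart. Below is a minimal standalone sketch of this shared-predicate pattern in plain C; use_vec is a hypothetical stand-in for ggml_metal_flash_attn_ext_use_vec, and the shapes in main are made up for illustration:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// hypothetical stand-in for ggml_metal_flash_attn_ext_use_vec: one
// predicate decides both which kernel is dispatched and whether the
// extra temp buffer is reserved, so the two paths stay in sync
static bool use_vec(int64_t head_size, int64_t batch_size) {
    return (batch_size < 20) && (head_size % 32 == 0);
}

int main(void) {
    printf("%d\n", use_vec(128,   1)); // 1: single-token decode -> vec kernel
    printf("%d\n", use_vec(128, 512)); // 0: large batch (prompt) -> half8x8 kernel
    printf("%d\n", use_vec( 80,   1)); // 0: head size not a multiple of 32 -> half8x8 kernel
    return 0;
}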
@@ -5066,9 +5077,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     GGML_ASSERT(ne01 < 65536);
 
-    // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-    if (ne01 >= 20 || (ne00 % 32 != 0)) {
+    if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
         // half8x8 kernel
         const int64_t nqptg = 8;  // queries per threadgroup       !! sync with kernel template arguments !!
         const int64_t ncpsg = 64; // cache values per simdgroup    !! sync with kernel template arguments !!
@@ -5293,14 +5302,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
     GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-    const int32_t nrows = ne1*ne2*ne3;
-
-    // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
     const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
 
-    //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
-    //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
-
     [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
     [encoder setThreadgroupMemoryLength:smem atIndex:0];
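
One detail worth noting in this hunk: the temp buffer is not a separate allocation. The alloc-size hook in the last hunk reserves extra bytes past the end of dst, and the encoder binds that tail of the same Metal buffer at offset offs_tmp. A tiny plain-C sketch of the offset arithmetic, with made-up sizes:

#include <stddef.h>
#include <stdio.h>

int main(void) {
    // made-up values; the real code gets these from the Metal buffer
    // mapping and from ggml_nbytes(dst)
    const size_t offs_dst   = 256;
    const size_t nbytes_dst = 16384;

    // first byte past dst within the same allocation: this is the tail
    // region that the alloc-size hook reserved for the partial results
    const size_t offs_tmp = offs_dst + nbytes_dst;

    printf("temp region starts at offset %zu\n", offs_tmp);
    return 0;
}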
@@ -5311,6 +5315,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     // reduce the results from the workgroups
     {
+        const int32_t nrows = ne1*ne2*ne3;
+
         ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
             nrows,
         };
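
The reduce pass merges the partial results that the nwg workgroups wrote into the temp buffer: each partial row carries an unnormalized output plus its softmax sum S and running max M (the "+ 2" in the alloc-size hunk below), and the merge is the usual online-softmax combine. The following is a plain-C sketch of that combine, not the actual Metal kernel; the per-workgroup layout assumed here (ne20 output values, then S, then M) is for illustration only:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// merge nwg partial attention results for one output row of size ne20;
// part[w] holds ne20 floats of unnormalized output, then S, then M
static void fa_vec_reduce_row(float * dst, const float * part,
                              int64_t nwg, int64_t ne20) {
    // global max of the per-workgroup maxima
    float M = -INFINITY;
    for (int64_t w = 0; w < nwg; ++w) {
        const float Mw = part[w*(ne20 + 2) + ne20 + 1];
        if (Mw > M) M = Mw;
    }

    float S = 0.0f;
    for (int64_t d = 0; d < ne20; ++d) dst[d] = 0.0f;

    for (int64_t w = 0; w < nwg; ++w) {
        const float * p  = part + w*(ne20 + 2);
        const float   Sw = p[ne20 + 0];
        const float   Mw = p[ne20 + 1];
        const float   c  = expf(Mw - M); // rescale to the global max

        S += c*Sw;
        for (int64_t d = 0; d < ne20; ++d) dst[d] += c*p[d];
    }

    // final softmax normalization
    for (int64_t d = 0; d < ne20; ++d) dst[d] /= S;
}

int main(void) {
    // two workgroups, head size 1: one softmax over logits {0, 1} with
    // values {2, 3}, split one logit per workgroup
    const float part[] = {
    //  O     S     M
        2.0f, 1.0f, 0.0f, // workgroup 0: logit 0, value 2
        3.0f, 1.0f, 1.0f, // workgroup 1: logit 1, value 3
    };
    float dst[1];
    fa_vec_reduce_row(dst, part, 2, 1);
    printf("%f\n", dst[0]); // expected (2 + e*3)/(1 + e) ~= 2.7311
    return 0;
}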
@@ -6149,30 +6155,31 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
 static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     size_t res = ggml_nbytes(tensor);
 
+    // some operations require additional memory for fleeting data:
     switch (tensor->op) {
         case GGML_OP_MUL_MAT_ID:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne21 = tensor->src[2]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
+                const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
 
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
                 res += ggml_type_size(GGML_TYPE_I32)*ne02;
-                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-                const int64_t nwg = 32;
+                if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
+                    const int64_t nwg = 32;
 
-                const int64_t ne01 = tensor->src[0]->ne[1];
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne03 = tensor->src[0]->ne[3];
-                const int64_t ne20 = tensor->src[2]->ne[0];
+                    const int64_t ne01 = tensor->src[0]->ne[1];
+                    const int64_t ne02 = tensor->src[0]->ne[2];
+                    const int64_t ne03 = tensor->src[0]->ne[3];
+                    const int64_t ne20 = tensor->src[2]->ne[0];
 
-                if (ne01 < 20) {
                     // temp buffer for writing the results from each workgroup
-                    // - ne20: the size of the head vector
+                    // - ne20: the size of the Value head
                     // - + 2: the S and M values for each intermediate result
+                    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
                     res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
                 }
             } break;
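
As a sanity check of the formula, here is a back-of-the-envelope computation of the reserved size with hypothetical decode-time dims (1 query token, 32 heads, 1 sequence, V head size 128, nwg = 32 workgroups):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical dims; in the real code ne01/ne02/ne03 come from the
    // Q tensor and ne20 from the V tensor
    const int64_t ne01 = 1, ne02 = 32, ne03 = 1, ne20 = 128, nwg = 32;

    // F32 partials: ne20 output values plus S and M, per workgroup
    const int64_t extra = 4*(ne01*ne02*ne03*nwg*(ne20 + 2));

    printf("%lld bytes (%.1f KiB)\n", (long long) extra, extra/1024.0);
    // -> 532480 bytes (520.0 KiB)
    return 0;
}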
