@@ -2521,7 +2521,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
25212521 /* .nb02 =*/ nb02,
25222522 /* .nb11 =*/ nb11,
25232523 /* .nb21 =*/ nb21,
2524-
25252524 };
25262525
25272526 [encoder setComputePipelineState: pipeline];
@@ -3166,54 +3165,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
31663165 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
31673166 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
31683167
3169- // use this branch to test the ggml_metal_mem_pool functionality
3170- #if 0
3171- // cpy to tmp buffer in MTLHeap
3172-
3173- id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
3174- if (!h_src0) {
3175- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
3176- return 0;
3177- }
3178-
3179- offs_src0 = 0;
3180-
3181- ggml_metal_kargs_cpy args_cpy = {
3182- /*.ne00 =*/ ne00,
3183- /*.ne01 =*/ ne01,
3184- /*.ne02 =*/ ne02,
3185- /*.ne03 =*/ ne03,
3186- /*.nb00 =*/ nb00,
3187- /*.nb01 =*/ nb01,
3188- /*.nb02 =*/ nb02,
3189- /*.nb03 =*/ nb03,
3190- /*.ne0 =*/ ne00,
3191- /*.ne1 =*/ ne01,
3192- /*.ne2 =*/ ne02,
3193- /*.ne3 =*/ ne03,
3194- /*.nb0 =*/ nb00,
3195- /*.nb1 =*/ nb01,
3196- /*.nb2 =*/ nb02,
3197- /*.nb3 =*/ nb03,
3198- };
3199-
3200- if (src0->type == GGML_TYPE_F16) {
3201- [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
3202- } else {
3203- [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
3204- }
3205- [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
3206- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3207- [encoder setBuffer:h_src0 offset:0 atIndex:2];
3208-
3209- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
3210- int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
3211-
3212- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
3213-
3214- #else
32153168 id <MTLBuffer > h_src0 = id_src0;
3216- # endif
3169+
32173170 // softmax
32183171
32193172 ggml_metal_kargs_soft_max args = {
@@ -4092,28 +4045,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
40924045 default : break ;
40934046 }
40944047
4095- // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
4096- // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
4097- // so we add this extra barrier to prevent the race.
4098- // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
4099- ggml_metal_encode_concurrency_reset (ctx_enc);
4100-
4101- // tokens per expert
4102- const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
4103- id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
4104- if (!h_tpe) {
4105- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
4106- return 0 ;
4107- }
4108-
4109- // id map
4110- // [n_tokens, n_expert]
4111- const size_t s_ids = ggml_type_size (GGML_TYPE_I32)*ne21*ne02;
4112- id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
4113- if (!h_ids) {
4114- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
4115- return 0 ;
4116- }
4048+ // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
4049+ size_t offs_tpe = offs_dst + ggml_nbytes (dst);
4050+ size_t offs_ids = offs_tpe + ggml_type_size (GGML_TYPE_I32)*ne02;
41174051
41184052 {
41194053 ggml_metal_kargs_mul_mm_id_map0 args = {
@@ -4151,8 +4085,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41514085 [encoder setComputePipelineState: pipeline];
41524086 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
41534087 [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 1 ];
4154- [encoder setBuffer: h_tpe offset: 0 atIndex: 2 ];
4155- [encoder setBuffer: h_ids offset: 0 atIndex: 3 ];
4088+ [encoder setBuffer: id_dst offset: offs_tpe atIndex: 2 ];
4089+ [encoder setBuffer: id_dst offset: offs_ids atIndex: 3 ];
41564090 [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
41574091
41584092 [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02, 1 , 1 )];
@@ -4214,8 +4148,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
42144148 [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
42154149 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
42164150 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
4217- [encoder setBuffer: h_tpe offset: 0 atIndex: 3 ];
4218- [encoder setBuffer: h_ids offset: 0 atIndex: 4 ];
4151+ [encoder setBuffer: id_dst offset: offs_tpe atIndex: 3 ];
4152+ [encoder setBuffer: id_dst offset: offs_ids atIndex: 4 ];
42194153 [encoder setBuffer: id_dst offset: offs_dst atIndex: 5 ];
42204154
42214155 [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
@@ -5306,6 +5240,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
53065240 GGML_ASSERT (ne01 < 65536 );
53075241
53085242 // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
5243+ // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
53095244 if (ne01 >= 20 || (ne00 % 32 != 0 )) {
53105245 // half8x8 kernel
53115246 const int64_t nqptg = 8 ; // queries per threadgroup !! sync with kernel template arguments !!
@@ -5531,30 +5466,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55315466 GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
55325467 GGML_ASSERT (ne1*ne2*ne3 <= (1u << 31 ));
55335468
5534- // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
5535- // still, we assume that concurrent FA won't happen before we do the refactor
5536- // ggml_metal_encode_concurrency_reset(ctx_enc);
5537-
55385469 const int32_t nrows = ne1*ne2*ne3;
55395470
5540- // temp buffer for writing the results from each workgroup
5541- // - ne20: the size of the head vector
5542- // - + 2: the S and M values for each intermediate result
5543- const size_t s_tmp = ggml_type_size (GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2 ));
5544- id <MTLBuffer > h_tmp = ggml_metal_mem_pool_alloc (mem_pool, s_tmp);
5545- if (!h_tmp) {
5546- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tmp);
5547- return 0 ;
5548- }
5471+ // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
5472+ const size_t offs_tmp = offs_dst + ggml_nbytes (dst);
55495473
55505474 // printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
55515475 // printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
55525476
5553- [encoder setBuffer: h_tmp offset: 0 atIndex: 6 ];
5477+ [encoder setBuffer: id_dst offset: offs_tmp atIndex: 6 ];
55545478
55555479 [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
55565480 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
55575481
5482+ // sync the 2 kernels
55585483 ggml_metal_encode_concurrency_reset (ctx_enc);
55595484
55605485 // reduce the results from the workgroups
@@ -5567,7 +5492,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55675492
55685493 [encoder setComputePipelineState: pipeline0];
55695494 [encoder setBytes: &args0 length: sizeof (args0) atIndex: 0 ];
5570- [encoder setBuffer: h_tmp offset: 0 atIndex: 1 ];
5495+ [encoder setBuffer: id_dst offset: offs_tmp atIndex: 1 ];
55715496 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
55725497
55735498 // printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@@ -6400,6 +6325,45 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
64006325 return max_size;
64016326}
64026327
// Reports the device allocation size for a tensor in this buffer type:
// the tensor data itself plus any extra scratch space that the Metal encoder
// carves out past the data (see the matching [TAG_METAL_EXTRA_SIZE_*] markers
// in ggml_metal_encode_node).
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    size_t size = ggml_nbytes(tensor);

    if (tensor->op == GGML_OP_MUL_MAT_ID) {
        // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
        const int64_t ne02 = tensor->src[0]->ne[2];
        const int64_t ne21 = tensor->src[2]->ne[1];

        // tokens per expert
        size += ggml_type_size(GGML_TYPE_I32)*ne02;

        // id map: [n_tokens, n_expert]
        size += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
    } else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
        // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
        const int64_t nwg = 32;

        const int64_t ne01 = tensor->src[0]->ne[1];
        const int64_t ne02 = tensor->src[0]->ne[2];
        const int64_t ne03 = tensor->src[0]->ne[3];
        const int64_t ne20 = tensor->src[2]->ne[0];

        if (ne01 < 20) {
            // small batches take the vec kernel, which needs a temp buffer for
            // the partial results of each workgroup:
            // - ne20: the size of the head vector
            // - + 2:  the S and M values for each intermediate result
            size += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
        }
    }

    return size;

    GGML_UNUSED(buft);
}
6366+
64036367static bool ggml_backend_metal_buffer_type_shared_is_host (ggml_backend_buffer_type_t buft) {
64046368 return false ;
64056369
@@ -6413,7 +6377,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
64136377 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
64146378 /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
64156379 /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
6416- /* .get_alloc_size = */ NULL , // defaults to ggml_nbytes
6380+ /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
64176381 /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
64186382 },
64196383 /* .device = */ &g_ggml_backend_metal_device,
0 commit comments