@@ -5583,7 +5583,7 @@ static int ggml_metal_encode_node(
                 // half4x4 kernel
                 const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-                const int64_t nkpsg = 1*ncpsg;
+                const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
 
                 GGML_ASSERT(nqptg <= 32);
                 GGML_ASSERT(nqptg % 1 == 0);
@@ -5602,6 +5602,7 @@ static int ggml_metal_encode_node(
                 int64_t nsgmax = 2;
                 while (true) {
                     const size_t smem = FATTN_SMEM(nsgmax);
+                    // avoid using more than half of the threadgroup memory - can cause slowdowns, especially for large head sizes
                     if (smem > device.maxThreadgroupMemoryLength/2) {
                         break;
                     }
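The probing loop above picks the largest power-of-two simdgroup count whose shared-memory footprint stays within half of the device's threadgroup-memory limit. A standalone sketch of the same search, using a stand-in fattn_smem() cost model (the real FATTN_SMEM macro depends on head size and other kernel parameters, so the linear growth assumed here is only an illustration):

    #include <stdbool.h>
    #include <stddef.h>

    // stand-in for the FATTN_SMEM(nsg) macro: assume the shared-memory
    // footprint grows linearly with the simdgroup count (hypothetical model)
    static size_t fattn_smem(int nsg) {
        return (size_t) nsg * 4096;
    }

    // mirror of the probing loop: double nsgmax until the footprint would
    // use more than half of the threadgroup memory
    static int probe_nsgmax(size_t max_tg_mem) {
        int nsgmax = 2;
        while (true) {
            const size_t smem = fattn_smem(nsgmax);
            if (smem > max_tg_mem/2) {
                break;
            }
            nsgmax *= 2;
        }
        return nsgmax; // first power of two that no longer fits the budget
    }

With a 32 KiB threadgroup-memory limit and the 4 KiB-per-simdgroup model above, the loop exits at nsgmax = 8, the first count whose footprint exceeds the 16 KiB budget.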
@@ -5642,8 +5643,16 @@ static int ggml_metal_encode_node(
                 // printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
 
-                // tokens per expert
-                const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*ne20 + ne01*ne02*ne03*nwg*2);
+                // sanity checks
+                GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+                GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+                const int32_t nrows = ne1*ne2*ne3;
+
+                // temp buffer for writing the results from each workgroup
+                // - ne20: the size of the head vector
+                // - + 2:  the S and M values for each intermediate result
+                const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
                 id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
                 if (!h_tmp) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
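The new buffer size follows directly from the layout described in the comments: for each of the nrows rows, each of the nwg workgroups writes an ne20-element output vector plus its two softmax statistics, S and M, all in f32. A small sketch of that arithmetic, with a hypothetical indexing helper (the actual element order inside the buffer is defined by the kernels, so the layout below is an assumption):

    #include <stddef.h>
    #include <stdint.h>

    // size in bytes of the intermediate buffer: per row and per workgroup,
    // ne20 output floats followed by the S and M statistics
    static size_t fattn_tmp_size(int32_t nrows, int32_t nwg, int64_t ne20) {
        return sizeof(float)*(size_t) nrows*nwg*(ne20 + 2);
    }

    // offset (in floats) of the partial result written by workgroup wg for
    // row `row` - hypothetical layout, for illustration only
    static size_t fattn_partial_offs(int32_t row, int32_t wg, int32_t nwg, int64_t ne20) {
        return ((size_t) row*nwg + wg)*(ne20 + 2);
    }

For example, with nrows = 512, nwg = 4 and a head vector of ne20 = 128 floats, this comes to 512*4*(128 + 2)*4 bytes, roughly 1.0 MiB.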
@@ -5662,10 +5671,8 @@ static int ggml_metal_encode_node(
                     // reduce the results from the workgroups
                     {
                         ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+                            nrows,
                             ne20,
-                            ne1,
-                            ne2,
-                            ne3,
                         };
 
                         id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
@@ -5676,7 +5683,7 @@ static int ggml_metal_encode_node(
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
                         // printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
-                        [encoder dispatchThreadgroups:MTLSizeMake((uint64_t) ne1*ne2*ne3, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
                     }
                 }
 #undef FATTN_SMEM
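With nrows and ne20 in the kargs, the reduce kernel no longer needs the separate ne1/ne2/ne3 extents: the dispatch launches one threadgroup of 32*32 threads per row, and each threadgroup folds that row's nwg partial results into the destination. A scalar CPU sketch of the usual split-k flash-attention merge that the S and M statistics enable follows; the (vector, S, M) element order and the unnormalized-accumulator convention are assumptions for illustration, not the actual Metal kernel:

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>

    // scalar reference for merging the nwg partial outputs of one row, each
    // carrying its own (S, M) softmax statistics; assumed layout per partial:
    // [0 .. ne20-1] unnormalized output accumulator, [ne20] = S, [ne20+1] = M
    static void fattn_reduce_row(const float * part, float * dst, int32_t nwg, int64_t ne20) {
        // pass 1: global maximum over the per-workgroup maxima
        float M = -INFINITY;
        for (int32_t wg = 0; wg < nwg; ++wg) {
            const float m = part[(size_t) wg*(ne20 + 2) + ne20 + 1];
            if (m > M) {
                M = m;
            }
        }

        // pass 2: rescale every partial into the common frame and accumulate
        float S = 0.0f;
        for (int64_t i = 0; i < ne20; ++i) {
            dst[i] = 0.0f;
        }
        for (int32_t wg = 0; wg < nwg; ++wg) {
            const float * p     = part + (size_t) wg*(ne20 + 2);
            const float   scale = expf(p[ne20 + 1] - M);

            S += scale*p[ne20];
            for (int64_t i = 0; i < ne20; ++i) {
                dst[i] += scale*p[i];
            }
        }

        // final normalization by the combined softmax sum
        for (int64_t i = 0; i < ne20; ++i) {
            dst[i] /= S;
        }
    }

Rescaling each partial by exp(M_wg - M) before summing keeps all of them in a common numerical frame, so the merged result is the same regardless of how the KV cache was split across workgroups.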