@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19751975 const bool has_mask = op->src [3 ] != nullptr ;
19761976
19771977 if (ggml_metal_op_flash_attn_ext_use_vec (op)) {
1978- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0 ;
1978+ // note: always reserve the padding space to avoid graph reallocations
1979+ // const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1980+ const bool has_kvpad = true ;
19791981
19801982 if (has_kvpad) {
19811983 res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19841986 (has_mask ? ggml_type_size (GGML_TYPE_F16)*ne31*ne32*ne33 : 0 ));
19851987 }
19861988 } else {
1987- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0 ;
1989+ // const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1990+ const bool has_kvpad = true ;
19881991
19891992 if (has_kvpad) {
19901993 res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
20202023 const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec (op);
20212024
20222025 // this optimization is not useful for the vector kernels
2023- if (is_vec) {
2024- return res;
2025- }
2026+ // note: always reserve the blk buffer to avoid graph reallocations
2027+ // if (is_vec) {
2028+ // return res;
2029+ // }
20262030
20272031 const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
20282032 const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
20492053
20502054 size_t res = 0 ;
20512055
2052- if (ggml_metal_op_flash_attn_ext_use_vec (op)) {
2056+ // note: always reserve the temp buffer to avoid graph reallocations
2057+ // if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2058+ if (true ) {
20532059 const int64_t nwg = 32 ;
2060+ const int64_t ne01_max = std::min (ne01, 32 );
20542061
20552062 // temp buffer for writing the results from each workgroup
20562063 // - ne20: the size of the Value head
20572064 // - + 2: the S and M values for each intermediate result
2058- res += ggml_type_size (GGML_TYPE_F32)*(ne01 *ne02*ne03*nwg*(ne20 + 2 ));
2065+ res += ggml_type_size (GGML_TYPE_F32)*(ne01_max *ne02*ne03*nwg*(ne20 + 2 ));
20592066 }
20602067
20612068 return res;
@@ -3523,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
35233530 ggml_metal_library_t lib = ctx->lib ;
35243531 ggml_metal_encoder_t enc = ctx->enc ;
35253532
3533+ GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
3534+
35263535 GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
35273536 GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
35283537 GGML_TENSOR_LOCALS ( int32_t , ne, op, ne);
35293538 GGML_TENSOR_LOCALS (uint32_t , nb, op, nb);
35303539
3540+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3541+
35313542 // bitonic sort requires the number of elements to be power of 2
3532- int64_t ne00_padded = 1 ;
3533- while (ne00_padded < ne00) {
3534- ne00_padded *= 2 ;
3543+ int nth = 1 ;
3544+ while (nth < ne00 && 2 *nth <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline) ) {
3545+ nth *= 2 ;
35353546 }
35363547
3537- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3538-
3539- const int64_t nrows = ggml_nrows (op->src [0 ]);
3548+ const int nptg = (ne00 + nth - 1 )/nth;
35403549
35413550 // Metal kernels require the buffer size to be multiple of 16 bytes
35423551 // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3543- const size_t smem = GGML_PAD (ne00_padded*sizeof (int32_t ), 16 );
3552+ const size_t smem = GGML_PAD (nth*sizeof (int32_t ), 16 );
3553+
3554+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
3555+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id (op);
3556+
3557+ ggml_metal_buffer_id bid_tmp = bid_dst;
3558+ bid_tmp.offs += ggml_nbytes (op);
3559+
3560+ if ((int ) ceil (std::log (nptg) / std::log (2 )) % 2 == 1 ) {
3561+ std::swap (bid_dst, bid_tmp);
3562+ }
35443563
35453564 ggml_metal_kargs_argsort args = {
3546- /* .ncols =*/ ne00,
3547- /* .ncols_pad =*/ ne00_padded
3565+ /* .ne00 =*/ ne00,
3566+ /* .ne01 =*/ ne01,
3567+ /* .ne02 =*/ ne02,
3568+ /* .ne03 =*/ ne03,
3569+ /* .nb00 =*/ nb00,
3570+ /* .nb01 =*/ nb01,
3571+ /* .nb02 =*/ nb02,
3572+ /* .nb03 =*/ nb03,
35483573 };
35493574
35503575 ggml_metal_encoder_set_pipeline (enc, pipeline);
35513576 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
3552- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op-> src [ 0 ]) , 1 );
3553- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 2 );
3577+ ggml_metal_encoder_set_buffer (enc, bid_src0 , 1 );
3578+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2 );
35543579
35553580 ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
35563581
3557- ggml_metal_encoder_dispatch_threadgroups (enc, 1 , nrows, 1 , ne00_padded, 1 , 1 );
3582+ ggml_metal_encoder_dispatch_threadgroups (enc, nptg*ne01, ne02, ne03, nth, 1 , 1 );
3583+
3584+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge (lib, op);
3585+
3586+ int len = nth;
3587+
3588+ while (len < ne00) {
3589+ ggml_metal_op_concurrency_reset (ctx);
3590+
3591+ ggml_metal_kargs_argsort_merge args_merge = {
3592+ .ne00 = ne00,
3593+ .ne01 = ne01,
3594+ .ne02 = ne02,
3595+ .ne03 = ne03,
3596+ .nb00 = nb00,
3597+ .nb01 = nb01,
3598+ .nb02 = nb02,
3599+ .nb03 = nb03,
3600+ .len = len,
3601+ };
3602+
3603+ // merges per row
3604+ const int nm = (ne00 + 2 *len - 1 ) / (2 *len);
3605+
3606+ const int nth = std::min (512 , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge));
3607+
3608+ ggml_metal_encoder_set_pipeline (enc, pipeline_merge);
3609+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof (args_merge), 0 );
3610+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
3611+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2 );
3612+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3 );
3613+
3614+ ggml_metal_encoder_set_threadgroup_memory_size (enc, 0 , 0 );
3615+
3616+ ggml_metal_encoder_dispatch_threadgroups (enc, nm*ne01, ne02, ne03, nth, 1 , 1 );
3617+
3618+ std::swap (bid_dst, bid_tmp);
3619+
3620+ len <<= 1 ;
3621+ }
35583622
35593623 return 1 ;
35603624}
0 commit comments