@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19751975 const bool has_mask = op->src [3 ] != nullptr ;
19761976
19771977 if (ggml_metal_op_flash_attn_ext_use_vec (op)) {
1978- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0 ;
1978+ // note: always reserve the padding space to avoid graph reallocations
1979+ // const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1980+ const bool has_kvpad = true ;
19791981
19801982 if (has_kvpad) {
19811983 res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19841986 (has_mask ? ggml_type_size (GGML_TYPE_F16)*ne31*ne32*ne33 : 0 ));
19851987 }
19861988 } else {
1987- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0 ;
1989+ // const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1990+ const bool has_kvpad = true ;
19881991
19891992 if (has_kvpad) {
19901993 res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
20202023 const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec (op);
20212024
20222025 // this optimization is not useful for the vector kernels
2023- if (is_vec) {
2024- return res;
2025- }
2026+ // note: always reserve the blk buffer to avoid graph reallocations
2027+ // if (is_vec) {
2028+ // return res;
2029+ // }
20262030
20272031 const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
20282032 const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
20492053
20502054 size_t res = 0 ;
20512055
2052- if (ggml_metal_op_flash_attn_ext_use_vec (op)) {
2056+ // note: always reserve the temp buffer to avoid graph reallocations
2057+ // if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2058+ if (true ) {
20532059 const int64_t nwg = 32 ;
2060+ const int64_t ne01_max = std::min (ne01, 32 );
20542061
20552062 // temp buffer for writing the results from each workgroup
20562063 // - ne20: the size of the Value head
20572064 // - + 2: the S and M values for each intermediate result
2058- res += ggml_type_size (GGML_TYPE_F32)*(ne01 *ne02*ne03*nwg*(ne20 + 2 ));
2065+ res += ggml_type_size (GGML_TYPE_F32)*(ne01_max *ne02*ne03*nwg*(ne20 + 2 ));
20592066 }
20602067
20612068 return res;
0 commit comments