Skip to content

Commit 2606b0a

Browse files
authored
metal : make the FA extra sizes consistent (#17143)
1 parent 307772f commit 2606b0a

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
1975 1975
const bool has_mask = op->src[3] != nullptr;
19761976

19771977
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
1978-
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1978+
// note: always reserve the padding space to avoid graph reallocations
1979+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
1980+
const bool has_kvpad = true;
19791981

19801982
if (has_kvpad) {
19811983
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19841986
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
19851987
}
19861988
} else {
1987-
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1989+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
1990+
const bool has_kvpad = true;
19881991

19891992
if (has_kvpad) {
19901993
res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
20202023
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
20212024

20222025
// this optimization is not useful for the vector kernels
2023-
if (is_vec) {
2024-
return res;
2025-
}
2026+
// note: always reserve the blk buffer to avoid graph reallocations
2027+
//if (is_vec) {
2028+
// return res;
2029+
//}
20262030

20272031
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
20282032
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
20492053

20502054
size_t res = 0;
20512055

2052-
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2056+
// note: always reserve the temp buffer to avoid graph reallocations
2057+
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2058+
if (true) {
20532059
const int64_t nwg = 32;
2060+
const int64_t ne01_max = std::min(ne01, 32);
20542061

20552062
// temp buffer for writing the results from each workgroup
20562063
// - ne20: the size of the Value head
20572064
// - + 2: the S and M values for each intermediate result
2058-
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2065+
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
20592066
}
20602067

20612068
return res;

0 commit comments

Comments (0)