
Commit 9724ea9

ikawrakow and Iwan Kawrakow authored
Attention mask tweaks for better long context performance (ikawrakow#825)
* Parallelize mask

  We see non-negligible PP gains for long contexts. More importantly, the strange drop in performance observed for GPT-OSS for contexts >= 32k tokens is gone.

* With FA on, create mask as f16 directly

* WIP

* Reduce KQ mask padding to 16

  Why was it 64 in the first place? I don't observe any issues, while TG performance for long contexts improves by 2-4%.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 1db0c49 commit 9724ea9
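
The "Parallelize mask" part of the change is not in the two hunks shown below (it lives in one of the other changed files), so what follows is only a minimal sketch of the idea, assuming the causal mask rows are split across worker threads. The helper names (fill_mask_rows, fill_mask_parallel), the chunking scheme, and the assumption that KV-cache slot j holds the token at position j are illustrative, not taken from this commit.

// Illustrative sketch only -- not code from this commit.
// Splits construction of the (n_tokens x n_kv) causal KQ mask across threads,
// one chunk of query rows per thread.
#include <algorithm>
#include <cmath>      // INFINITY
#include <cstdint>
#include <thread>
#include <vector>

static void fill_mask_rows(float * mask, int64_t n_kv, int row_begin, int row_end,
                           const int * pos) {
    for (int i = row_begin; i < row_end; ++i) {
        for (int64_t j = 0; j < n_kv; ++j) {
            // causal: query token i may attend to cache slot j only if that slot's
            // position is <= pos[i] (here we assume slot j holds position j)
            mask[i*n_kv + j] = j <= pos[i] ? 0.0f : -INFINITY;
        }
    }
}

static void fill_mask_parallel(float * mask, int64_t n_kv, int n_tokens,
                               const int * pos, int n_threads) {
    std::vector<std::thread> workers;
    const int chunk = (n_tokens + n_threads - 1)/n_threads;
    for (int t = 0; t < n_threads; ++t) {
        const int begin = t*chunk;
        const int end   = std::min(n_tokens, begin + chunk);
        if (begin >= end) break;
        workers.emplace_back(fill_mask_rows, mask, n_kv, begin, end, pos);
    }
    for (auto & w : workers) w.join();
}

For long prompts the fill is O(n_tokens * n_kv) work, which is why a serial mask build can become visible in prompt-processing time; splitting it by rows is embarrassingly parallel, since each row is written independently.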

3 files changed: +277, -25 lines


ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -2235,7 +2235,7 @@ extern "C" {
         int min_entries,
         float thresh);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 16
 
     // q: [n_embd, n_batch, n_head, 1]
    // k: [n_embd, n_kv, n_head_kv, 1]
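
The header change above shrinks the padding of the mask's second dimension. The mask is allocated as n_kv x GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) (see the next file), so the pad mostly matters for token generation, where n_tokens is 1. A rough illustration follows, using a local pad_to() helper assumed to match GGML_PAD's round-up-to-a-multiple behavior; the 64k-entry KV cache is just an example figure.

// Illustrative only: how much mask data the smaller pad saves per step.
#include <cstdio>

// assumed equivalent of GGML_PAD: round x up to a multiple of n
static long pad_to(long x, long n) { return (x + n - 1)/n*n; }

int main() {
    const long n_kv = 65536;                 // e.g. a 64k-entry KV cache
    for (long n_tokens : {1L, 17L, 512L}) {  // TG, a small batch, a PP batch
        const long old_elems = n_kv*pad_to(n_tokens, 64);
        const long new_elems = n_kv*pad_to(n_tokens, 16);
        printf("n_tokens = %3ld: mask elements %10ld -> %10ld\n",
               n_tokens, old_elems, new_elems);
    }
    return 0;
}

For n_tokens = 1 the padded mask drops from n_kv x 64 to n_kv x 16 values to fill and upload every decode step, which lines up with the 2-4% TG improvement for long contexts reported in the commit message.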

src/llama-build-context.cpp

Lines changed: 12 additions & 0 deletions
@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
 }
 
 ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+        return lctx.inp_KQ_mask;
+    }
     lctx.inp_KQ_mask = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
         : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
 
 ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
     GGML_ASSERT(hparams.n_swa > 0);
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+        return lctx.inp_KQ_mask_swa;
+    }
 
     lctx.inp_KQ_mask_swa = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
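
With flash attention enabled, the two functions above now allocate the causal KQ mask (and its sliding-window variant) as GGML_TYPE_F16 instead of GGML_TYPE_F32, so the flash-attention path can consume it without a separate F32-to-F16 conversion. The code that actually fills the mask is outside these hunks; the snippet below is only a sketch of writing the values as f16 directly, assuming ggml's public ggml_fp32_to_fp16() helper and, again, that cache slot j holds the token at position j.

// Illustrative sketch only -- not code from this commit.
#include <cmath>    // INFINITY
#include "ggml.h"   // ggml_fp16_t, ggml_fp32_to_fp16

static void fill_f16_causal_row(ggml_fp16_t * row, int64_t n_kv, int64_t token_pos) {
    const ggml_fp16_t zero    = ggml_fp32_to_fp16(0.0f);
    const ggml_fp16_t neg_inf = ggml_fp32_to_fp16(-INFINITY);
    for (int64_t j = 0; j < n_kv; ++j) {
        row[j] = j <= token_pos ? zero : neg_inf;   // masked-out slots get -inf
    }
}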
