Commit 6305eb7

fix
1 parent e341ec6 commit 6305eb7

File tree: 1 file changed, +10 -1 lines changed


src/llama.cpp

Lines changed: 10 additions & 1 deletion
@@ -591,12 +591,21 @@ static struct ggml_tensor * llm_build_kqv(
 
     struct ggml_tensor * padded_v = v;
     int64_t n_embd_head_v_out = n_embd_head_v;
+    // make sure the feature dimension is padded correctly (assuming v's feature dimension is ne[2])
     if (n_embd_head_v < n_embd_head_k) {
-        padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
+        padded_v = ggml_pad(ctx, v,
+            0,                              // do not pad dim 0
+            0,                              // do not pad dim 1
+            n_embd_head_k - n_embd_head_v,  // pad the feature dimension (dim 2)
+            0);
         cb(padded_v, "padded_v", il);
         n_embd_head_v_out = n_embd_head_k;
     }
 
+    // make sure the Flash Attention input dimensions line up
+    GGML_ASSERT(padded_v->ne[2] == k->ne[2]); // feature dimensions match
+    GGML_ASSERT(q->ne[1]        == k->ne[1]); // sequence lengths match
+
     cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
             hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
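For reference, ggml_pad(ctx, a, p0, p1, p2, p3) appends p_i zero-valued elements to the end of dimension i of a. The standalone sketch below illustrates that behaviour for the kind of padding this commit performs; it is not the llm_build_kqv code path, and the toy shapes and head sizes (8, 4, and 128 vs. 192) are assumptions made purely for illustration.

    // Minimal ggml program: grow dim 2 of a toy "V" tensor from n_embd_head_v
    // to n_embd_head_k by appending zeros, mirroring the ggml_pad call above.
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head_k = 192;   // assumed K head size
        const int64_t n_embd_head_v = 128;   // assumed (smaller) V head size

        // toy tensor with the "feature" size on dim 2, as the commit comment assumes
        struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 4, n_embd_head_v);

        // append zeros on dim 2 only, so ne[2] grows from 128 to 192
        struct ggml_tensor * padded_v = ggml_pad(ctx, v,
            0,                                      // dim 0 unchanged
            0,                                      // dim 1 unchanged
            (int)(n_embd_head_k - n_embd_head_v),   // dim 2 grows by 64
            0);                                     // dim 3 unchanged

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, padded_v);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        printf("v:        [%lld, %lld, %lld]\n",
               (long long) v->ne[0], (long long) v->ne[1], (long long) v->ne[2]);
        printf("padded_v: [%lld, %lld, %lld]\n",
               (long long) padded_v->ne[0], (long long) padded_v->ne[1], (long long) padded_v->ne[2]);

        ggml_free(ctx);
        return 0;
    }

Because the extra slots are filled with zeros, any attention output landing in those padded feature positions is itself zero (a weighted sum of zeros), which is why padding V up to K's head size can be a safe way to satisfy a kernel that expects matching dimensions.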