Skip to content

Commit 524d5bc

Browse files
author
wangzaijun
committed
fix
1 parent 0882355 commit 524d5bc

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@ def test_decode_attentions(
     for _ in range(test_count):
         q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10
         q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10
-        kv_buffer_shape = [test_seq_len + 10, kv_nope_shape[1], kv_nope_shape[2] + kv_rope_shape[2]]
+        kv_buffer_shape = [
+            (test_seq_len + 10) * infer_state.batch_size,
+            kv_nope_shape[1],
+            kv_nope_shape[2] + kv_rope_shape[2],
+        ]
         kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10

         kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]]

0 commit comments

Comments (0)