1 parent 0882355 commit 524d5bc
lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py
@@ -88,7 +88,11 @@ def test_decode_attentions(
     for _ in range(test_count):
         q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10
         q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10
-        kv_buffer_shape = [test_seq_len + 10, kv_nope_shape[1], kv_nope_shape[2] + kv_rope_shape[2]]
+        kv_buffer_shape = [
+            (test_seq_len + 10) * infer_state.batch_size,
+            kv_nope_shape[1],
+            kv_nope_shape[2] + kv_rope_shape[2],
+        ]
         kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10

         kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]]
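
The fix scales the KV buffer's first dimension by infer_state.batch_size: the test allocates one flat buffer shared by every request in the batch, so with only test_seq_len + 10 token slots in total, any batch larger than one would presumably index past the end of the buffer. Below is a minimal standalone sketch of the corrected sizing; the concrete values (batch_size, num_kv_heads, nope_dim, rope_dim) are illustrative assumptions, not taken from the test, which derives its shapes from the model config and infer_state.

import torch

# Illustrative values; the real test derives these from the model
# config and infer_state rather than hard-coding them.
batch_size = 4           # stands in for infer_state.batch_size
test_seq_len = 128
num_kv_heads = 1         # stands in for kv_nope_shape[1]
nope_dim, rope_dim = 512, 64

# One flat KV buffer for the whole batch: each request owns a block of
# (test_seq_len + 10) token slots, so the first dimension is scaled by
# batch_size. Before the fix it held slots for a single request only.
kv_buffer = torch.randn(
    (test_seq_len + 10) * batch_size,
    num_kv_heads,
    nope_dim + rope_dim,
    dtype=torch.bfloat16,
) / 10

# The nope and rope halves sit side by side in the last dimension,
# mirroring the kv_nope slice taken in the test.
kv_nope = kv_buffer[:, :, :nope_dim]
kv_rope = kv_buffer[:, :, nope_dim:]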