Skip to content

Commit 99dd942

Browse files
author
wangzaijun
committed
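fix page_table_copy (removes the `max_seq_len_k` parameter; the copy length is now derived inside the function from `page_table.shape[1]`, and both callers are updated accordingly)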
1 parent e986ffb commit 99dd942

File tree

2 files changed: +4 -7 lines changed

2 files changed

+4
-7
lines changed

lightllm/common/basemodel/triton_kernel/fa3_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def page_table_copy(
3333
page_table, # destination tensor [batch, seq]
3434
req_to_token_indexs, # source tensor [batch, seq]
3535
b_req_idx, # request index to copy from
36-
max_seq_len_k, # sequence length to copy
3736
):
3837
assert page_table.dim() == 2, "page_table should be 2D"
3938
assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D"
4039

40+
max_seq_len_k = page_table.shape[1]
4141
batch_size = page_table.size(0)
4242
BLOCK_SIZE = 128
4343

@@ -57,10 +57,9 @@ def page_table_copy(
5757
)
5858

5959

60-
import torch
61-
62-
6360
def test_page_table_copy():
61+
import torch
62+
6463
batch_size, seq_len = 2, 8
6564

6665
req_to_token_indexs = torch.arange(batch_size * seq_len, dtype=torch.int32).reshape(batch_size, seq_len).cuda()
@@ -69,9 +68,8 @@ def test_page_table_copy():
6968

7069
b_req_idx = torch.tensor([0, 2, 1, 3], dtype=torch.int32, device="cuda")[::2]
7170
print(b_req_idx.stride())
72-
max_seq_len_k = seq_len
7371

74-
page_table_copy(page_table, req_to_token_indexs, b_req_idx, max_seq_len_k)
72+
page_table_copy(page_table, req_to_token_indexs, b_req_idx)
7573

7674
print("req_to_token_indexs:")
7775
print(req_to_token_indexs.cpu().numpy())

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
5757
page_table=self.page_table[:, :max_seq_len_k],
5858
req_to_token_indexs=model.req_manager.req_to_token_indexs,
5959
b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
60-
max_seq_len_k=max_seq_len_k,
6160
)
6261
if args_mtp_step > 0:
6362
self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()

0 commit comments

Comments
 (0)