@@ -33,11 +33,11 @@ def page_table_copy(
     page_table,  # destination tensor [batch, seq]
     req_to_token_indexs,  # source tensor [batch, seq]
     b_req_idx,  # request index to copy from
-    max_seq_len_k,  # sequence length to copy
 ):
     assert page_table.dim() == 2, "page_table should be 2D"
     assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D"
 
+    max_seq_len_k = page_table.shape[1]
     batch_size = page_table.size(0)
     BLOCK_SIZE = 128
 
@@ -57,10 +57,9 @@ def page_table_copy(
     )
 
 
-import torch
-
-
 def test_page_table_copy():
+    import torch
+
     batch_size, seq_len = 2, 8
 
     req_to_token_indexs = torch.arange(batch_size * seq_len, dtype=torch.int32).reshape(batch_size, seq_len).cuda()
@@ -69,9 +68,8 @@ def test_page_table_copy():
 
     b_req_idx = torch.tensor([0, 2, 1, 3], dtype=torch.int32, device="cuda")[::2]
     print(b_req_idx.stride())
-    max_seq_len_k = seq_len
 
-    page_table_copy(page_table, req_to_token_indexs, b_req_idx, max_seq_len_k)
+    page_table_copy(page_table, req_to_token_indexs, b_req_idx)
 
     print("req_to_token_indexs:")
     print(req_to_token_indexs.cpu().numpy())
0 commit comments