Commit a561c63

Author: niushengxiao
Parent: a4bc0d6

fix: fix a bug in flashinfer

File tree

2 files changed: +4 -4 lines


lightllm/models/deepseek2/flashinfer_struct.py (1 addition & 1 deletion)

@@ -38,7 +38,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.b_req_idx,
             self.b_seq_len,
             self.b_start_loc,
-            self.max_len_in_batch,
+            self.max_kv_seq_len,
             self.kv_indices,
         )
         if self.decode_wrapper is None:

lightllm/models/llama/flashinfer_struct.py (3 additions & 3 deletions)

@@ -41,7 +41,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.b_req_idx,
             self.b_seq_len,
             self.b_start_loc,
-            self.max_len_in_batch,
+            self.max_kv_seq_len,
             self.kv_indices,
         )
         self.kv_starts = self.b1_cu_kv_seq_len.int()
@@ -81,8 +81,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.req_manager.req_to_token_indexs,
             self.b_req_idx,
             self.b_seq_len,
-            kv_starts,
-            self.max_len_in_batch,
+            kv_starts[:-1],
+            self.max_kv_seq_len,
             kv_indices,
         )
         self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
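
For context, here is a minimal sketch in plain PyTorch of why the two llama changes go together (the tensor names mirror the diff, but the values and the reading of the names are assumptions, not taken from the lightllm source). A cumulative-length tensor built with a leading zero has batch_size + 1 entries (the "indptr" layout), so a kernel that expects one start offset per request should receive only the first batch_size of them, and the bound passed next to it should be the longest KV sequence rather than the longest input in the batch.

import torch

# Hypothetical per-request KV-cache lengths for a batch of 3 requests.
b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)

# Exclusive prefix sum with a leading zero: batch_size + 1 entries,
# i.e. the indptr layout [0, 5, 8, 15].
kv_starts = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32)
kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)

# A kernel that wants one start offset per request must not receive
# the trailing total, hence the [:-1] slice in the commit:
per_request_starts = kv_starts[:-1]  # tensor([0, 5, 8])

# The companion change: the bound is the maximum KV sequence length,
# a property of b_seq_len, not of the number of new input tokens.
max_kv_seq_len = int(b_seq_len.max())  # 7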
