
Commit d770a02

fix
1 parent f87b3ba commit d770a02

3 files changed: +21 additions, -8 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,6 @@ def __init__(self, kvargs):
         self._verify_must()
         self._verify_params()
         self._init_quant()
-        self._init_inferstate_cls()
 
         # more contiguous GPU memory allocation gives better performance
         if self.max_total_token_num is None:
@@ -92,6 +91,7 @@ def __init__(self, kvargs):
         self._init_infer_layer()
         self._init_some_value()
         self._init_custom()
+        self._init_inferstate_cls()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
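
The reordering above is presumably needed because _init_inferstate_cls() can now allocate backend-specific state (see lightllm/models/llama/model.py below) that reads attributes produced by the earlier initialization steps. A toy sketch of that ordering dependency; TinyModel and its attribute values are hypothetical stand-ins, not LightLLM code:

# Hypothetical sketch of the call-order dependency; not the LightLLM classes.
class TinyModel:
    def _init_some_value(self):
        # stand-ins for attributes the real init steps would compute
        self.max_seq_length = 4096
        self.graph_max_batch_size = 16

    def _init_inferstate_cls(self):
        # reads attributes set above, so it has to run after them
        self.kv_indices_elems = self.graph_max_batch_size * self.max_seq_length

    def __init__(self):
        self._init_some_value()
        self._init_inferstate_cls()


print(TinyModel().kv_indices_elems)  # 65536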

lightllm/models/llama/flashinfer_struct.py

Lines changed: 9 additions & 3 deletions
@@ -23,9 +23,15 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
                 self.kv_last_page_len_buffer = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
-                self.kv_indices = torch.empty(
-                    self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
-                ).to(input_ids.device)
+                if self.batch_size <= model.graph_max_batch_size:
+                    self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
+                        : self.batch_size * self.flashinfer_extra_state.max_seq_length
+                    ]
+                else:
+                    self.kv_indices = torch.empty(
+                        self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
+                    ).to(input_ids.device)
+
                 repack_kv_index(
                     self.req_manager.req_to_token_indexs,
                     self.b_req_idx,
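
A minimal standalone sketch of the buffer-reuse pattern above; the module-level names and the get_kv_indices helper are illustrative, not the LightLLM API. Slicing a preallocated per-microbatch buffer keeps the kv_indices storage address stable from step to step, which is presumably what the CUDA-graph decode path (batch_size <= graph_max_batch_size) relies on, while the torch.empty fallback still covers batches too large to be captured:

import torch

graph_max_batch_size, max_seq_length = 16, 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# one persistent buffer per microbatch, allocated once up front
kv_indices_buffer = [
    torch.empty(graph_max_batch_size * max_seq_length, dtype=torch.int32, device=device)
    for _ in range(2)
]


def get_kv_indices(batch_size: int, microbatch_index: int) -> torch.Tensor:
    if batch_size <= graph_max_batch_size:
        # view into the persistent buffer: same storage address every call
        return kv_indices_buffer[microbatch_index][: batch_size * max_seq_length]
    # fallback: a fresh allocation for batches that exceed the captured size
    return torch.empty(batch_size * max_seq_length, dtype=torch.int32, device=device)


a = get_kv_indices(8, 0)
b = get_kv_indices(8, 0)
assert a.data_ptr() == b.data_ptr()  # stable address when served from the buffer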

lightllm/models/llama/model.py

Lines changed: 11 additions & 4 deletions
@@ -30,6 +30,14 @@ def __init__(self, model):
         self.head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
         self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id())
         self.max_seq_length = model.max_seq_length
+        self.kv_indices_buffer = [
+            torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
+                get_current_device_id()
+            ),
+            torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
+                get_current_device_id()
+            ),
+        ]
         self.q_data_type = model.data_type
         self.kv_data_type = model.data_type
 
@@ -51,8 +59,6 @@ def __init__(self, kvargs):
         self.enable_flashinfer = (
             get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode
         )
-        if self.enable_flashinfer:
-            self.infer_state_class = LlamaFlashInferStateInfo
         super().__init__(kvargs)
         return
 
@@ -61,8 +67,6 @@ def _init_config(self):
         # rename key
         # repair_config()
         self._reset_num_key_value_heads()
-        if self.enable_flashinfer:
-            self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self)
         return
 
     def _reset_num_key_value_heads(self):
@@ -90,6 +94,9 @@ def _init_mem_manager(self):
     def _init_inferstate_cls(self):
         if get_env_start_args().enable_fa3:
             self.infer_state_class = FlashAttentionStateInfo
+        elif self.enable_flashinfer:
+            self.infer_state_class = LlamaFlashInferStateInfo
+            self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self)
 
     def _init_custom(self):
         """
