
Commit 7ca7b0d

Author: niushengxiao

fix: remove use_dynamic_prompt_cache code in flashinfer_struct.py

1 parent 42e8199 · commit 7ca7b0d

File tree

2 files changed: +38 −58 lines changed

lightllm/models/llama/flashinfer_struct.py

Lines changed: 33 additions & 49 deletions
@@ -68,55 +68,39 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         q_starts = self.b1_cu_q_seq_len.int()
         kv_starts = self.b1_cu_kv_seq_len.int()
         kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
-        if self.use_dynamic_prompt_cache:
-            kv_indices = torch.empty(
-                self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
-            ).to(input_ids.device)
-            repack_kv_index(
-                self.req_manager.req_to_token_indexs,
-                self.b_req_idx,
-                self.b_seq_len,
-                self.b_start_loc,
-                self.max_len_in_batch,
-                kv_indices,
-            )
-            self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_extra_state.workspace_buffer,
-                qo_indptr_buf=q_starts,
-                paged_kv_indptr_buf=kv_starts,
-                paged_kv_indices_buf=kv_indices,
-                paged_kv_last_page_len_buf=kv_last_page_len,
-            )
-            self.prefill_wrapper.plan(
-                q_starts,
-                kv_starts,
-                kv_indices,
-                kv_last_page_len,
-                self.flashinfer_extra_state.tp_q_head_num,
-                self.flashinfer_extra_state.tp_kv_head_num,
-                self.flashinfer_extra_state.head_dim,
-                1,
-                causal=True,
-                pos_encoding_mode="NONE",
-                logits_soft_cap=0.0,
-                q_data_type=self.flashinfer_extra_state.q_data_type,
-                kv_data_type=self.flashinfer_extra_state.kv_data_type,
-            )
-        else:
-            self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-                self.flashinfer_extra_state.workspace_buffer,
-            )
-            self.prefill_wrapper.plan(
-                qo_indptr=q_starts,
-                kv_indptr=kv_starts,
-                num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
-                num_kv_heads=self.flashinfer_extra_state.tp_kv_head_num,
-                head_dim_qk=self.flashinfer_extra_state.head_dim,
-                head_dim_vo=self.flashinfer_extra_state.head_dim,
-                causal=True,
-                q_data_type=self.flashinfer_extra_state.q_data_type,
-                kv_data_type=self.flashinfer_extra_state.kv_data_type,
-            )
+        kv_indices = torch.empty(
+            self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
+        ).to(input_ids.device)
+        repack_kv_index(
+            self.req_manager.req_to_token_indexs,
+            self.b_req_idx,
+            self.b_seq_len,
+            self.b_start_loc,
+            self.max_len_in_batch,
+            kv_indices,
+        )
+        self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_extra_state.workspace_buffer,
+            qo_indptr_buf=q_starts,
+            paged_kv_indptr_buf=kv_starts,
+            paged_kv_indices_buf=kv_indices,
+            paged_kv_last_page_len_buf=kv_last_page_len,
+        )
+        self.prefill_wrapper.plan(
+            q_starts,
+            kv_starts,
+            kv_indices,
+            kv_last_page_len,
+            self.flashinfer_extra_state.tp_q_head_num,
+            self.flashinfer_extra_state.tp_kv_head_num,
+            self.flashinfer_extra_state.head_dim,
+            1,
+            causal=True,
+            pos_encoding_mode="NONE",
+            logits_soft_cap=0.0,
+            q_data_type=self.flashinfer_extra_state.q_data_type,
+            kv_data_type=self.flashinfer_extra_state.kv_data_type,
+        )
         return
 
     def copy_for_cuda_graph(self, new_infer_state):
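
Note: with the ragged fallback removed, prefill always goes through the paged wrapper, so kv_indices must always be repacked from the request-to-token table. As a reading aid, here is a minimal pure-PyTorch sketch of what repack_kv_index computes; the simplified signature and the running-offset loop are assumptions based on the call site above (the real implementation is a dedicated kernel that also takes b_start_loc and max_len_in_batch).

import torch

def repack_kv_index_reference(req_to_token_indexs, b_req_idx, b_seq_len, kv_indices):
    # Concatenate each request's valid token slot ids into one flat buffer.
    # With page_size=1 (the `1` passed to plan above), every "page" holds one
    # token, so kv_indices is just the per-token slots laid out back to back
    # and kv_last_page_len is all ones.
    offset = 0
    for i in range(b_req_idx.numel()):
        n = int(b_seq_len[i])
        kv_indices[offset : offset + n] = req_to_token_indexs[b_req_idx[i], :n]
        offset += n

# kv_starts (paged_kv_indptr) is the prefix sum [0, len0, len0+len1, ...], so
# request i's pages live at kv_indices[kv_starts[i] : kv_starts[i + 1]].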

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 5 additions & 9 deletions
@@ -237,15 +237,11 @@ def init_model(self, kvargs):
             raise e

         set_random_seed(2147483647)
-        self.radix_cache = (
-            RadixCache(
-                get_unique_server_name(),
-                self.model.mem_manager.size,
-                self.rank_in_node,
-                mem_manager=self.model.mem_manager,
-            )
-            if self.use_dynamic_prompt_cache
-            else None
+        self.radix_cache = RadixCache(
+            get_unique_server_name(),
+            self.model.mem_manager.size,
+            self.rank_in_node,
+            mem_manager=self.model.mem_manager,
         )

         if "prompt_cache_kv_buffer" in model_cfg:
