@@ -68,55 +68,39 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
6868 q_starts = self .b1_cu_q_seq_len .int ()
6969 kv_starts = self .b1_cu_kv_seq_len .int ()
7070 kv_last_page_len = torch .full ((self .batch_size ,), 1 , dtype = torch .int32 ).to (input_ids .device )
71- if self .use_dynamic_prompt_cache :
72- kv_indices = torch .empty (
73- self .batch_size * self .flashinfer_extra_state .max_seq_length , dtype = torch .int32
74- ).to (input_ids .device )
75- repack_kv_index (
76- self .req_manager .req_to_token_indexs ,
77- self .b_req_idx ,
78- self .b_seq_len ,
79- self .b_start_loc ,
80- self .max_len_in_batch ,
81- kv_indices ,
82- )
83- self .prefill_wrapper = flashinfer .prefill .BatchPrefillWithPagedKVCacheWrapper (
84- self .flashinfer_extra_state .workspace_buffer ,
85- qo_indptr_buf = q_starts ,
86- paged_kv_indptr_buf = kv_starts ,
87- paged_kv_indices_buf = kv_indices ,
88- paged_kv_last_page_len_buf = kv_last_page_len ,
89- )
90- self .prefill_wrapper .plan (
91- q_starts ,
92- kv_starts ,
93- kv_indices ,
94- kv_last_page_len ,
95- self .flashinfer_extra_state .tp_q_head_num ,
96- self .flashinfer_extra_state .tp_kv_head_num ,
97- self .flashinfer_extra_state .head_dim ,
98- 1 ,
99- causal = True ,
100- pos_encoding_mode = "NONE" ,
101- logits_soft_cap = 0.0 ,
102- q_data_type = self .flashinfer_extra_state .q_data_type ,
103- kv_data_type = self .flashinfer_extra_state .kv_data_type ,
104- )
105- else :
106- self .prefill_wrapper = flashinfer .prefill .BatchPrefillWithRaggedKVCacheWrapper (
107- self .flashinfer_extra_state .workspace_buffer ,
108- )
109- self .prefill_wrapper .plan (
110- qo_indptr = q_starts ,
111- kv_indptr = kv_starts ,
112- num_qo_heads = self .flashinfer_extra_state .tp_q_head_num ,
113- num_kv_heads = self .flashinfer_extra_state .tp_kv_head_num ,
114- head_dim_qk = self .flashinfer_extra_state .head_dim ,
115- head_dim_vo = self .flashinfer_extra_state .head_dim ,
116- causal = True ,
117- q_data_type = self .flashinfer_extra_state .q_data_type ,
118- kv_data_type = self .flashinfer_extra_state .kv_data_type ,
119- )
71+ kv_indices = torch .empty (
72+ self .batch_size * self .flashinfer_extra_state .max_seq_length , dtype = torch .int32
73+ ).to (input_ids .device )
74+ repack_kv_index (
75+ self .req_manager .req_to_token_indexs ,
76+ self .b_req_idx ,
77+ self .b_seq_len ,
78+ self .b_start_loc ,
79+ self .max_len_in_batch ,
80+ kv_indices ,
81+ )
82+ self .prefill_wrapper = flashinfer .prefill .BatchPrefillWithPagedKVCacheWrapper (
83+ self .flashinfer_extra_state .workspace_buffer ,
84+ qo_indptr_buf = q_starts ,
85+ paged_kv_indptr_buf = kv_starts ,
86+ paged_kv_indices_buf = kv_indices ,
87+ paged_kv_last_page_len_buf = kv_last_page_len ,
88+ )
89+ self .prefill_wrapper .plan (
90+ q_starts ,
91+ kv_starts ,
92+ kv_indices ,
93+ kv_last_page_len ,
94+ self .flashinfer_extra_state .tp_q_head_num ,
95+ self .flashinfer_extra_state .tp_kv_head_num ,
96+ self .flashinfer_extra_state .head_dim ,
97+ 1 ,
98+ causal = True ,
99+ pos_encoding_mode = "NONE" ,
100+ logits_soft_cap = 0.0 ,
101+ q_data_type = self .flashinfer_extra_state .q_data_type ,
102+ kv_data_type = self .flashinfer_extra_state .kv_data_type ,
103+ )
120104 return
121105
122106 def copy_for_cuda_graph (self , new_infer_state ):
0 commit comments