@@ -54,13 +54,16 @@ class _FlashInferPlanner:
    decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
    cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
    plan_params: Optional[PlanParams]
+    plan_params_prefill: Optional[PlanParams]
+    plan_params_decode: Optional[PlanParams]

    def __init__(self):
        self.workspace_buffer = None
        self.prefill_wrapper = None
        self.decode_wrapper = None
        self.cached_decode_wrappers = {}
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

    def _init_decode_wrapper(self):
        assert self.workspace_buffer is not None
@@ -82,11 +85,45 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
        self.decode_wrapper = self._init_decode_wrapper()

    def reset(self) -> None:
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

-    def plan(
+    def plan_prefill(
+        self,
+        qo_indptr_host: torch.Tensor,
+        kv_page_indptr_host: torch.Tensor,
+        kv_page_indices: torch.Tensor,
+        kv_last_page_len_host: torch.Tensor,
+        kv_lens_arr_host: torch.Tensor,
+        seq_len_host: torch.Tensor,
+        plan_params: PlanParams,
+    ) -> None:
+        # check for re-planning
+        if plan_params != self.plan_params_prefill:
+            # plan prefill
+            self.prefill_wrapper.plan(
+                qo_indptr_host,
+                kv_page_indptr_host,
+                kv_page_indices,
+                kv_last_page_len_host,
+                plan_params.n_heads,  # Q heads
+                plan_params.n_kv_heads,  # KV heads
+                plan_params.head_dim,
+                plan_params.page_size,
+                causal=plan_params.causal,
+                q_data_type=plan_params.q_dtype,
+                kv_data_type=plan_params.kv_dtype,
+                sm_scale=plan_params.sm_scale,
+                # max_token_per_sequence=max(seq_len_host).item(),
+                seq_lens=kv_lens_arr_host,
+            )
+            self.plan_params_prefill = plan_params
+
+        # return prefill wrapper
+        return self.prefill_wrapper
+
+    def plan_decode(
        self,
-        qo_indptr: torch.Tensor,
        kv_page_indptr: torch.Tensor,
        kv_page_indices: torch.Tensor,
        kv_last_page_len: torch.Tensor,
@@ -126,29 +163,12 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
            return wrapper

        # check for re-planning
-        if plan_params != self.plan_params:
-            if plan_params.is_generate:
-                _plan_decode(self.decode_wrapper)
-            else:
-                # plan prefill
-                self.prefill_wrapper.plan(
-                    qo_indptr,
-                    kv_page_indptr,
-                    kv_page_indices,
-                    kv_last_page_len,
-                    plan_params.n_heads,  # Q heads
-                    plan_params.n_kv_heads,  # KV heads
-                    plan_params.head_dim,
-                    plan_params.page_size,
-                    causal=plan_params.causal,
-                    q_data_type=plan_params.q_dtype,
-                    kv_data_type=plan_params.kv_dtype,
-                    sm_scale=plan_params.sm_scale,
-                )
-            self.plan_params = plan_params
-
-        # return desired wrapper
-        return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
+        if plan_params != self.plan_params_decode:
+            _plan_decode(self.decode_wrapper)
+            self.plan_params_decode = plan_params
+
+        # return decode wrapper
+        return self.decode_wrapper


_GlobalFlashInferPlanner = _FlashInferPlanner()
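With this change the planner caches the last-planned `PlanParams` separately for prefill and decode, so alternating between the two phases no longer invalidates the other path's plan. A minimal, self-contained sketch of that caching policy, using hypothetical stand-ins (`_Params`, `_TinyPlanner`) rather than the real FlashInfer wrappers:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class _Params:
    """Hypothetical stand-in for PlanParams (frozen -> comparable/hashable)."""

    num_seq: int
    is_generate: bool


class _TinyPlanner:
    """Hypothetical stand-in for _FlashInferPlanner's re-planning policy."""

    def __init__(self):
        self.plan_params_prefill: Optional[_Params] = None
        self.plan_params_decode: Optional[_Params] = None
        self.num_plans = 0  # counts how often a (costly) plan() call would run

    def plan_prefill(self, p: _Params):
        if p != self.plan_params_prefill:  # re-plan only when prefill params change
            self.num_plans += 1
            self.plan_params_prefill = p

    def plan_decode(self, p: _Params):
        if p != self.plan_params_decode:  # re-plan only when decode params change
            self.num_plans += 1
            self.plan_params_decode = p


planner = _TinyPlanner()
for _ in range(3):  # steps that run both prefill and decode with stable shapes
    planner.plan_prefill(_Params(num_seq=2, is_generate=False))
    planner.plan_decode(_Params(num_seq=8, is_generate=True))

# One plan per path; a single shared cache slot would re-plan on every call.
assert planner.num_plans == 2
```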
@@ -211,10 +231,14 @@ def flashinfer_mha_with_cache(
    v: torch.Tensor,
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
    cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
    cache_loc: torch.Tensor,
    last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
    # EXTRA METADATA
    flashinfer_batch_indices: torch.Tensor,
    flashinfer_positions: torch.Tensor,
@@ -240,30 +264,11 @@ def flashinfer_mha_with_cache(
    # convert to flashinfer-style metadata
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
-
-    qo_indptr = cu_seqlen[: num_seq + 1]
-    paged_kv_indptr = cu_num_pages[: num_seq + 1]
-
-    # NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
-    # truncated and will point to the correct sub range of cache_loc.
-    paged_kv_indices = cache_loc
-    paged_kv_last_page_len = last_page_len[:num_seq]
+    num_total_tokens = num_prefill_tokens + num_decode

    n_heads = q.shape[1]
    n_kv_heads = k.shape[1]

-    pp = PlanParams(
-        n_heads=n_heads,
-        n_kv_heads=n_kv_heads,
-        head_dim=head_dim,
-        num_seq=len(qo_indptr) - 1,
-        is_generate=(s == 1),
-        page_size=k_cache.shape[1],
-        q_dtype=q.dtype,
-        kv_dtype=k_cache.dtype,
-        sm_scale=scale,
-    )
-
    # Assuming k_scale = v_scale = 1.0
    k_scale, v_scale = 1.0, 1.0
    # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
@@ -272,28 +277,94 @@ def flashinfer_mha_with_cache(
        v = v.to(torch.float8_e4m3fn)

    flashinfer.page.append_paged_kv_cache(
-        k,
-        v,
-        flashinfer_batch_indices,
-        flashinfer_positions,
-        (k_cache, v_cache),
-        paged_kv_indices,
-        paged_kv_indptr,
-        paged_kv_last_page_len,
+        append_key=k,
+        append_value=v,
+        batch_indices=flashinfer_batch_indices,
+        positions=flashinfer_positions,
+        paged_kv_cache=(k_cache, v_cache),
+        kv_indices=cache_loc,
+        kv_indptr=cu_num_pages[: num_seq + 1],
+        kv_last_page_len=last_page_len[:num_seq],
    )

-    # run the flashinfer planner and obtain the correct wrapper
-    wrapper = _GlobalFlashInferPlanner.plan(
-        qo_indptr,
-        paged_kv_indptr,
-        paged_kv_indices,
-        paged_kv_last_page_len,
-        pp,
-    )
+    # check if we need to re-combine outputs
+    if num_prefill > 0 and num_decode > 0:
+        y = torch.empty_like(q)
+    else:
+        y = None
+
+    # now run split prefill, decode
+    if num_prefill > 0:
+        q_prefill = q[:num_prefill_tokens]
+
+        pp_prefill = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_prefill,
+            is_generate=False,
+            page_size=k_cache.shape[1],
+            q_dtype=q_prefill.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )

-    y = wrapper.run(
-        q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
-    )
+        wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
+            qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
+            kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len_host=last_page_len_host[:num_prefill],
+            kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
+            seq_len_host=seq_len_host[:num_prefill],
+            plan_params=pp_prefill,
+        )
+
+        y_prefill = wrapper_prefill.run(
+            q_prefill,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[:num_prefill_tokens] = y_prefill
+        else:
+            y = y_prefill
+
+    if num_decode > 0:
+        q_decode = q[num_prefill_tokens:num_total_tokens]
+
+        pp_decode = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_decode,
+            is_generate=True,
+            page_size=k_cache.shape[1],
+            q_dtype=q_decode.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )
+
+        # run the flashinfer planner and obtain the correct wrapper
+        wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
+            kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len=last_page_len[num_prefill:num_seq],
+            plan_params=pp_decode,
+        )
+
+        y_decode = wrapper_decode.run(
+            q_decode,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[num_prefill_tokens:num_total_tokens] = y_decode
+        else:
+            y = y_decode

    return y.view(q_shape_og)  # [b,s,n*h_d] or [b,s, n, h_d]

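The split execution above relies on a flattened token layout in which all prefill tokens come first, followed by exactly one token per decode sequence. A small sketch of that layout and of how the two outputs are stitched back together, with made-up sizes and plain tensors standing in for the two wrapper.run() results:

```python
import torch

# Made-up batch: 2 prefill sequences (5 and 7 tokens) and 3 decode sequences.
num_prefill, num_decode = 2, 3
num_prefill_tokens = 5 + 7
num_total_tokens = num_prefill_tokens + num_decode  # one query token per decode seq

n_heads, head_dim = 4, 8
q = torch.randn(num_total_tokens, n_heads, head_dim)  # flattened [tokens, heads, head_dim]

q_prefill = q[:num_prefill_tokens]                 # context-phase queries
q_decode = q[num_prefill_tokens:num_total_tokens]  # generate-phase queries

# Stand-ins for y_prefill = wrapper_prefill.run(...) and y_decode = wrapper_decode.run(...).
y_prefill = torch.zeros_like(q_prefill)
y_decode = torch.ones_like(q_decode)

# Recombine into one output, written to disjoint slices as in the op above.
y = torch.empty_like(q)
y[:num_prefill_tokens] = y_prefill
y[num_prefill_tokens:num_total_tokens] = y_decode
assert y.shape == q.shape
```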
@@ -306,10 +377,14 @@ def flashinfer_mha_with_cache_fake(
    v: torch.Tensor,
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
    cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
    cache_loc: torch.Tensor,
    last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
    # EXTRA METADATA
    flashinfer_batch_indices: torch.Tensor,
    flashinfer_positions: torch.Tensor,
@@ -358,7 +433,17 @@ def get_cached_attention_op(cls) -> MHACallable:

    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return [
+            "batch_info_host",
+            "cu_seqlen_host",
+            "cu_num_pages",
+            "cu_num_pages_host",
+            "cache_loc",
+            "last_page_len",
+            "last_page_len_host",
+            "seq_len_with_cache_host",
+            "seq_len_host",
+        ]

    @classmethod
    def get_prepare_extra_metadata_info(