@@ -205,14 +205,16 @@ class _FlashInferPlanner:
     cached_cuda_graph_decode_wrappers: Dict[
         PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
     ]
-    plan_params: Optional[PlanParams]
+    plan_params_prefill: Optional[PlanParams]
+    plan_params_decode: Optional[PlanParams]

     def __init__(self):
         self.workspace_buffer = None
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.cached_cuda_graph_decode_wrappers = {}
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

     def _init_decode_wrapper(
         self,
@@ -253,7 +255,8 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
         self.decode_wrapper = self._init_decode_wrapper()

     def reset(self) -> None:
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

     def plan_generate_only(
         self,
@@ -279,9 +282,42 @@ def plan_generate_only(
             sm_scale=plan_params.sm_scale,
         )

-    def plan(
+    def plan_prefill(
+        self,
+        qo_indptr_host: torch.Tensor,
+        kv_page_indptr_host: torch.Tensor,
+        kv_page_indices: torch.Tensor,
+        kv_last_page_len_host: torch.Tensor,
+        kv_lens_arr_host: torch.Tensor,
+        seq_len_host: torch.Tensor,
+        plan_params: PlanParams,
+    ) -> flashinfer.BatchPrefillWithPagedKVCacheWrapper:
+        # check for re-planning
+        if plan_params != self.plan_params_prefill:
+            # plan prefill
+            self.prefill_wrapper.plan(
+                qo_indptr_host,
+                kv_page_indptr_host,
+                kv_page_indices,
+                kv_last_page_len_host,
+                plan_params.n_heads,  # Q heads
+                plan_params.n_kv_heads,  # KV heads
+                plan_params.head_dim,
+                plan_params.page_size,
+                causal=plan_params.causal,
+                q_data_type=plan_params.q_dtype,
+                kv_data_type=plan_params.kv_dtype,
+                sm_scale=plan_params.sm_scale,
+                # max_token_per_sequence=max(seq_len_host).item(),
+                seq_lens=kv_lens_arr_host,
+            )
+            self.plan_params_prefill = plan_params
+
+        # return prefill wrapper
+        return self.prefill_wrapper
+
+    def plan_decode(
         self,
-        qo_indptr: torch.Tensor,
         kv_page_indptr: torch.Tensor,
         kv_page_indices: torch.Tensor,
         kv_last_page_len: torch.Tensor,
@@ -328,29 +364,12 @@ def _plan_decode(
             return wrapper

         # check for re-planning
-        if plan_params != self.plan_params:
-            if plan_params.is_generate:
-                _plan_decode(self.decode_wrapper)
-            else:
-                # plan prefill
-                self.prefill_wrapper.plan(
-                    qo_indptr,
-                    kv_page_indptr,
-                    kv_page_indices,
-                    kv_last_page_len,
-                    plan_params.n_heads,  # Q heads
-                    plan_params.n_kv_heads,  # KV heads
-                    plan_params.head_dim,
-                    plan_params.page_size,
-                    causal=plan_params.causal,
-                    q_data_type=plan_params.q_dtype,
-                    kv_data_type=plan_params.kv_dtype,
-                    sm_scale=plan_params.sm_scale,
-                )
-            self.plan_params = plan_params
+        if plan_params != self.plan_params_decode:
+            _plan_decode(self.decode_wrapper)
+            self.plan_params_decode = plan_params

-        # return desired wrapper
-        return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
+        # return decode wrapper
+        return self.decode_wrapper


 _GlobalFlashInferPlanner = _FlashInferPlanner()
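Note: the hunks above replace the single `plan()` entry point, which dispatched on `plan_params.is_generate`, with separate `plan_prefill()` / `plan_decode()` methods, each caching its own `PlanParams`. Below is a minimal, self-contained sketch of that caching pattern; `PlanParams` here is pared down and `FakeWrapper` is a hypothetical stand-in, not the flashinfer wrapper API.

# Minimal sketch of the per-phase re-planning cache introduced above:
# wrapper.plan() only runs when the requested PlanParams differ from the
# cached ones; otherwise the previously planned wrapper is reused as-is.
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)  # frozen -> comparable with == and hashable
class PlanParams:
    n_heads: int
    page_size: int


class FakeWrapper:
    def plan(self, params: PlanParams) -> None:
        print(f"planning for {params}")


class Planner:
    def __init__(self) -> None:
        self.prefill_wrapper = FakeWrapper()
        self.plan_params_prefill: Optional[PlanParams] = None

    def plan_prefill(self, params: PlanParams) -> FakeWrapper:
        if params != self.plan_params_prefill:  # re-plan only on change
            self.prefill_wrapper.plan(params)
            self.plan_params_prefill = params
        return self.prefill_wrapper


planner = Planner()
planner.plan_prefill(PlanParams(8, 16))  # plans
planner.plan_prefill(PlanParams(8, 16))  # cache hit: no re-plan
planner.plan_prefill(PlanParams(8, 32))  # params changed: plans again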
@@ -412,10 +431,14 @@ def flashinfer_mha_with_cache(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -441,30 +464,11 @@ def flashinfer_mha_with_cache(
     # convert to flashinfer-style metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
-
-    qo_indptr = cu_seqlen[: num_seq + 1]
-    paged_kv_indptr = cu_num_pages[: num_seq + 1]
-
-    # NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
-    # truncated and will point to the correct sub range of cache_loc.
-    paged_kv_indices = cache_loc
-    paged_kv_last_page_len = last_page_len[:num_seq]
+    num_total_tokens = num_prefill_tokens + num_decode

     n_heads = q.shape[1]
     n_kv_heads = k.shape[1]

-    pp = PlanParams(
-        n_heads=n_heads,
-        n_kv_heads=n_kv_heads,
-        head_dim=head_dim,
-        num_seq=len(qo_indptr) - 1,
-        is_generate=(s == 1),
-        page_size=k_cache.shape[1],
-        q_dtype=q.dtype,
-        kv_dtype=k_cache.dtype,
-        sm_scale=scale,
-    )
-
     # Assuming k_scale = v_scale = 1.0
     k_scale, v_scale = 1.0, 1.0
     # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
@@ -473,28 +477,94 @@ def flashinfer_mha_with_cache(
         v = v.to(torch.float8_e4m3fn)

     flashinfer.page.append_paged_kv_cache(
-        k,
-        v,
-        flashinfer_batch_indices,
-        flashinfer_positions,
-        (k_cache, v_cache),
-        paged_kv_indices,
-        paged_kv_indptr,
-        paged_kv_last_page_len,
+        append_key=k,
+        append_value=v,
+        batch_indices=flashinfer_batch_indices,
+        positions=flashinfer_positions,
+        paged_kv_cache=(k_cache, v_cache),
+        kv_indices=cache_loc,
+        kv_indptr=cu_num_pages[: num_seq + 1],
+        kv_last_page_len=last_page_len[:num_seq],
     )

-    # run the flashinfer planner and obtain the correct wrapper
-    wrapper = _GlobalFlashInferPlanner.plan(
-        qo_indptr,
-        paged_kv_indptr,
-        paged_kv_indices,
-        paged_kv_last_page_len,
-        pp,
-    )
+    # check if we need to re-combine outputs
+    if num_prefill > 0 and num_decode > 0:
+        y = torch.empty_like(q)
+    else:
+        y = None
+
+    # now run split prefill, decode
+    if num_prefill > 0:
+        q_prefill = q[:num_prefill_tokens]
+
+        pp_prefill = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_prefill,
+            is_generate=False,
+            page_size=k_cache.shape[1],
+            q_dtype=q_prefill.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )

-    y = wrapper.run(
-        q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
-    )
+        wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
+            qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
+            kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len_host=last_page_len_host[:num_prefill],
+            kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
+            seq_len_host=seq_len_host[:num_prefill],
+            plan_params=pp_prefill,
+        )
+
+        y_prefill = wrapper_prefill.run(
+            q_prefill,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[:num_prefill_tokens] = y_prefill
+        else:
+            y = y_prefill
+
+    if num_decode > 0:
+        q_decode = q[num_prefill_tokens:num_total_tokens]
+
+        pp_decode = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_decode,
+            is_generate=True,
+            page_size=k_cache.shape[1],
+            q_dtype=q_decode.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )
+
+        # run the flashinfer planner and obtain the correct wrapper
+        wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
+            kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len=last_page_len[num_prefill:num_seq],
+            plan_params=pp_decode,
+        )
+
+        y_decode = wrapper_decode.run(
+            q_decode,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[num_prefill_tokens:num_total_tokens] = y_decode
+        else:
+            y = y_decode

     return y.view(q_shape_og)  # [b,s,n*h_d] or [b,s, n, h_d]
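Note: the split path above assumes a flattened token layout in `q`: all prefill tokens first, then exactly one token per decode sequence (hence `num_total_tokens = num_prefill_tokens + num_decode`). A small sketch with made-up shapes, independent of flashinfer:

# Sketch (hypothetical shapes) of the flattened token layout the split
# prefill/decode path relies on.
import torch

num_prefill, num_prefill_tokens, num_decode = 2, 10, 3
num_total_tokens = num_prefill_tokens + num_decode  # 10 + 3 = 13

q = torch.randn(num_total_tokens, 8, 64)  # [tokens, n_heads, head_dim]

q_prefill = q[:num_prefill_tokens]  # tokens of the 2 prefill sequences
q_decode = q[num_prefill_tokens:num_total_tokens]  # 3 single decode tokens

# with both phases present, outputs are scattered back into one buffer,
# mirroring the y[:num_prefill_tokens] / y[num_prefill_tokens:] writes above
y = torch.empty_like(q)
y[:num_prefill_tokens] = q_prefill  # stand-in for y_prefill
y[num_prefill_tokens:num_total_tokens] = q_decode  # stand-in for y_decode
assert y.shape == q.shape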
@@ -507,10 +577,14 @@ def flashinfer_mha_with_cache_fake(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -559,7 +633,17 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return [
+            "batch_info_host",
+            "cu_seqlen_host",
+            "cu_num_pages",
+            "cu_num_pages_host",
+            "cache_loc",
+            "last_page_len",
+            "last_page_len_host",
+            "seq_len_with_cache_host",
+            "seq_len_host",
+        ]

     @classmethod
     def get_prepare_extra_metadata_info(
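Note: the expanded metadata list mirrors the new `*_host` parameters of `flashinfer_mha_with_cache`. This diff does not show how the host tensors are produced; a plausible (assumed, not from this PR) sketch is that they are CPU mirrors of the device metadata, prepared once per step so `plan_prefill` can read sizes without a GPU-to-CPU sync inside the attention op:

# Hypothetical sketch: host-side mirrors of device metadata tensors.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
cu_num_pages = torch.tensor([0, 2, 5, 6], dtype=torch.int32, device=device)

# copy to host outside the hot path / captured region (assumed placement)
cu_num_pages_host = cu_num_pages.cpu()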