     Constant,
     MHACallable,
     PrepareMetadataCallable,
+    PrepareMetadataHostCallable,
     SequenceInfo,
 )

@@ -183,7 +184,6 @@ class PlanParams:
     n_kv_heads: int
     head_dim: int
     num_seq: int
-    is_generate: bool
     page_size: int
     q_dtype: torch.dtype
     kv_dtype: torch.dtype
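For context on this removal: `PlanParams` is used further down as the lookup key for `cached_cuda_graph_decode_wrappers`, so dropping `is_generate` simply removes one component from that cache key. A minimal sketch of the pattern with a reduced, partly assumed field list and made-up values:

# Sketch only; not the actual PlanParams definition from this file.
from dataclasses import dataclass
import torch

@dataclass(frozen=True)  # frozen -> hashable, so instances can key a dict
class PlanParamsSketch:
    n_kv_heads: int
    head_dim: int
    num_seq: int
    page_size: int
    q_dtype: torch.dtype
    kv_dtype: torch.dtype

cached_wrappers: dict[PlanParamsSketch, object] = {}
pp = PlanParamsSketch(8, 128, 4, 16, torch.float16, torch.float16)
if pp not in cached_wrappers:
    cached_wrappers[pp] = object()  # stand-in for a planned decode wrapper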
@@ -289,12 +289,17 @@ def plan_prefill(
         kv_page_indices: torch.Tensor,
         kv_last_page_len_host: torch.Tensor,
         kv_lens_arr_host: torch.Tensor,
-        seq_len_host: torch.Tensor,
         plan_params: PlanParams,
     ) -> None:
         # check for re-planning
         if plan_params != self.plan_params_prefill:
             # plan prefill
+            # NOTE (lucaslie): we use host versions here. the plan actually needs both (host+device)
+            # version. Unfortunately, there is no good way to access the plan API and provide both
+            # although we have both available. I have decided to use the host versions here to
+            # ensure non-blocking invocation of plan, whereas the other way around would trigger a
+            # blocking copy to cpu. This way we trigger a non-blocking copy to device (note that
+            # this is safe since we do have pinned CPU memory for all our host-side arguments).
             self.prefill_wrapper.plan(
                 qo_indptr_host,
                 kv_page_indptr_host,
@@ -308,7 +313,6 @@ def plan_prefill(
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
                 sm_scale=plan_params.sm_scale,
-                # max_token_per_sequence=max(seq_len_host).item(),
                 seq_lens=kv_lens_arr_host,
             )
             self.plan_params_prefill = plan_params
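The NOTE above argues for passing the host-side tensors so that `plan` can push them to the device asynchronously, rather than passing device tensors and forcing a blocking device-to-CPU copy. A small illustration of that trade-off, assuming pinned host buffers as stated in the comment:

# Illustration of the copy-direction trade-off described in the NOTE (sketch only).
import torch

# Host-side metadata kept in pinned (page-locked) CPU memory.
qo_indptr_host = torch.zeros(9, dtype=torch.int32, pin_memory=True)

# Host -> device from pinned memory can be issued without blocking the CPU:
qo_indptr_dev = qo_indptr_host.to("cuda", non_blocking=True)

# Device -> host, by contrast, synchronizes with the producing stream:
# qo_indptr_back = qo_indptr_dev.cpu()  # blocking copy back to CPU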
@@ -359,7 +363,6 @@ def _plan_decode(
             _plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
         # check if we are in cuda graph capture and just return the pre-cached decode wrapper
         if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
-            assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
             wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
             return wrapper

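With `is_generate` gone, the assert has nothing to check; the code relies on the wrapper cache being keyed by `PlanParams` alone. During CUDA graph capture or warm-up no (re-)planning can happen, so the pre-planned wrapper for the current `PlanParams` is simply looked up and returned. A hedged sketch of that lookup pattern; the factory callback is hypothetical:

# Sketch: only pre-planned wrappers may be used while a CUDA graph is being captured.
import torch

def lookup_or_plan(cache: dict, plan_params, make_and_plan):
    if torch.cuda.is_current_stream_capturing():
        # Capture in progress: re-planning is not allowed, the key must already exist.
        return cache[plan_params]
    if plan_params not in cache:
        cache[plan_params] = make_and_plan(plan_params)  # hypothetical factory that calls plan()
    return cache[plan_params]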
@@ -423,6 +426,23 @@ def prepare_flashinfer_metadata_fake(
     )


+def prepare_flashinfer_metadata_host(
+    batch_info_host: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
+    cache_loc_host: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+) -> None:
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
+
+    if num_prefill == 0:
+        _GlobalFlashInferPlanner.plan_generate_only(
+            num_decode,
+            cu_num_pages_host[: num_decode + 1],
+            cache_loc_host,
+            last_page_len_host[:num_decode],
+        )
+
+
 @torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
 def flashinfer_mha_with_cache(
     # Q, K, V
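The new `prepare_flashinfer_metadata_host` carries the logic that previously lived in `host_prepare_for_forward` (removed further down), but takes the host views as explicit arguments and only plans decode for generate-only batches (`num_prefill == 0`). A hypothetical call with made-up shapes and values, assuming pinned host tensors:

# prepare_flashinfer_metadata_host is the function added above (import path omitted).
import torch

num_decode = 3
batch_info_host = torch.tensor([0, 0, num_decode], dtype=torch.int32)  # num_prefill == 0
cu_num_pages_host = torch.tensor([0, 2, 4, 5], dtype=torch.int32, pin_memory=True)
cache_loc_host = torch.arange(5, dtype=torch.int32, pin_memory=True)
last_page_len_host = torch.tensor([7, 3, 16], dtype=torch.int32, pin_memory=True)

# Plans FlashInfer decode for the generate-only batch; a no-op when num_prefill > 0.
prepare_flashinfer_metadata_host(
    batch_info_host, cu_num_pages_host, cache_loc_host, last_page_len_host
)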
@@ -438,7 +458,6 @@ def flashinfer_mha_with_cache(
     last_page_len: torch.Tensor,
     last_page_len_host: torch.Tensor,
     seq_len_with_cache_host: torch.Tensor,
-    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -502,7 +521,6 @@ def flashinfer_mha_with_cache(
         n_kv_heads=n_kv_heads,
         head_dim=head_dim,
         num_seq=num_prefill,
-        is_generate=False,
         page_size=k_cache.shape[1],
         q_dtype=q_prefill.dtype,
         kv_dtype=k_cache.dtype,
@@ -515,7 +533,6 @@ def flashinfer_mha_with_cache(
         kv_page_indices=cache_loc,
         kv_last_page_len_host=last_page_len_host[:num_prefill],
         kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
-        seq_len_host=seq_len_host[:num_prefill],
         plan_params=pp_prefill,
     )

@@ -539,7 +556,6 @@ def flashinfer_mha_with_cache(
         n_kv_heads=n_kv_heads,
         head_dim=head_dim,
         num_seq=num_decode,
-        is_generate=True,
         page_size=k_cache.shape[1],
         q_dtype=q_decode.dtype,
         kv_dtype=k_cache.dtype,
@@ -584,7 +600,6 @@ def flashinfer_mha_with_cache_fake(
     last_page_len: torch.Tensor,
     last_page_len_host: torch.Tensor,
     seq_len_with_cache_host: torch.Tensor,
-    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -642,7 +657,6 @@ def get_standard_metadata_args(cls) -> List[str]:
             "last_page_len",
             "last_page_len_host",
             "seq_len_with_cache_host",
-            "seq_len_host",
         ]

     @classmethod
@@ -684,18 +698,8 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:
         return {"workspace_buffer": _init_workspace}

     @classmethod
-    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
-        batch_info = sequence_info._input_buffer.get_host_view("batch_info")
-        num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
-        # Call plan for generate-only batches.
-        if num_prefill == 0:
-            _GlobalFlashInferPlanner.plan_generate_only(
-                num_decode,
-                sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
-                sequence_info._input_buffer.get_host_view("cache_loc"),
-                sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
-            )
-        return
+    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
+        return prepare_flashinfer_metadata_host

     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
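With `host_prepare_for_forward` removed, the backend now only exposes the host-side prepare function through `get_host_prepare_metadata_function`, and the caller is expected to pass in the host views itself. A rough sketch of such a call site; the `FlashInferAttention` name and the `input_buffer` object are assumptions, only the hook itself comes from this diff:

# Hypothetical caller-side wiring of the new hook (names other than the hook are assumed).
host_prepare = FlashInferAttention.get_host_prepare_metadata_function()
if host_prepare is not None:
    host_prepare(
        input_buffer.get_host_view("batch_info"),
        input_buffer.get_host_view("cu_num_pages"),
        input_buffer.get_host_view("cache_loc"),
        input_buffer.get_host_view("last_page_len"),
    )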