 )


+# TODO: remove this when flashinfer version is updated to >0.5
+def fast_decode_plan(
+    wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+    non_blocking: bool = True,
+    fixed_split_size: Optional[int] = None,
+    disable_split_kv: bool = False,
+    global_override_indptr_cpu: Optional[torch.Tensor] = None,
+) -> None:
+    """
+    Copied from flashinfer.decode.fast_decode_plan in flashinfer version >0.5.
+    Does not exist in flashinfer version 0.3.1, hence copied here.
+    """
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
+    if wrapper.use_tensor_cores:
+        qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
+    # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
+    if fixed_split_size is None:
+        fixed_split_size = -1
+
+    if wrapper.is_cuda_graph_enabled:
+        if batch_size != wrapper._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, wrapper._fixed_batch_size
+                )
+            )
+        if len(indices) > len(wrapper._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        wrapper._paged_kv_indptr_buf = indptr
+        wrapper._paged_kv_indices_buf = indices
+        wrapper._paged_kv_last_page_len_buf = last_page_len
+        if wrapper.use_tensor_cores:
+            wrapper._qo_indptr_buf = qo_indptr_host.to(wrapper.device, non_blocking=non_blocking)
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type),
+        device=wrapper.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type),
+        device=wrapper.device,
+    )
+
+    indptr_host = (
+        global_override_indptr_cpu if global_override_indptr_cpu is not None else indptr.cpu()
+    )
+
+    with torch.cuda.device(wrapper.device):
+        if wrapper.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            if page_size == 1:
+                # When page size is 1, last_page_len is always 1.
+                # Directly construct the host tensor rather than executing a device-to-host copy.
+                last_page_len_host = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
+            else:
+                last_page_len_host = last_page_len.cpu()
+
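+            # get_seq_lens derives per-request KV lengths from the page table:
+            # (num_pages - 1) * page_size + last_page_len.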
+            kv_lens_arr_host = flashinfer.get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                wrapper._plan_info = wrapper._cached_module.plan(
+                    wrapper._float_workspace_buffer,
+                    wrapper._int_workspace_buffer,
+                    wrapper._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    wrapper.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in tensor core plan: {e}") from e
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                wrapper._plan_info = wrapper._cached_module.plan(
+                    wrapper._float_workspace_buffer,
+                    wrapper._int_workspace_buffer,
+                    wrapper._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    wrapper.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}") from e
+
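+    # Cache the remaining runtime parameters on the wrapper; flashinfer's run() reads them back.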
+    wrapper._pos_encoding_mode = pos_encoding_mode
+    wrapper._window_left = window_left
+    wrapper._logits_soft_cap = logits_soft_cap
+    wrapper._sm_scale = sm_scale
+    wrapper._rope_scale = rope_scale
+    wrapper._rope_theta = rope_theta
+
+
 @dataclass
 class PlanParams:
     """Parameters that affect the flashinfer execution plan."""
@@ -52,21 +202,42 @@ class _FlashInferPlanner:
     workspace_buffer: Optional[torch.Tensor]
     prefill_wrapper: Optional[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
     decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
-    cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
+    cached_cuda_graph_decode_wrappers: Dict[
+        PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
+    ]
     plan_params: Optional[PlanParams]

     def __init__(self):
         self.workspace_buffer = None
         self.prefill_wrapper = None
         self.decode_wrapper = None
-        self.cached_decode_wrappers = {}
+        self.cached_cuda_graph_decode_wrappers = {}
         self.plan_params = None

-    def _init_decode_wrapper(self):
+    def _init_decode_wrapper(
+        self,
+        use_cuda_graph: bool = False,
+        indptr: Optional[torch.Tensor] = None,
+        indices: Optional[torch.Tensor] = None,
+        last_page_len: Optional[torch.Tensor] = None,
+    ):
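+        """Create a decode wrapper, optionally bound to fixed CUDA-graph metadata buffers."""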
         assert self.workspace_buffer is not None
-        return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-            self.workspace_buffer, "NHD", use_tensor_cores=True
-        )
+        if use_cuda_graph:
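+            # flashinfer requires pre-allocated indptr/indices/last_page_len buffers when the
+            # wrapper is used under CUDA graph capture and replay.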
+            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                paged_kv_indptr_buffer=indptr,
+                paged_kv_indices_buffer=indices,
+                paged_kv_last_page_len_buffer=last_page_len,
+                use_tensor_cores=True,
+            )
+        else:
+            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_tensor_cores=True,
+            )

     def init_workspace(self, workspace_buffer: torch.Tensor):
         self.__init__()  # reset all state
@@ -84,6 +255,30 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
     def reset(self) -> None:
         self.plan_params = None

+    def plan_generate_only(
+        self,
+        num_seq: int,
+        cu_num_pages: torch.Tensor,
+        cache_loc: torch.Tensor,
+        last_page_len: torch.Tensor,
+    ):
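+        """Re-plan the cached CUDA-graph decode wrappers that match num_seq using host-side metadata."""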
+        for plan_params in self.cached_cuda_graph_decode_wrappers:
+            if plan_params.num_seq == num_seq:
+                wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
+                fast_decode_plan(
+                    wrapper,
+                    cu_num_pages,
+                    cache_loc,
+                    last_page_len,
+                    plan_params.n_heads,
+                    plan_params.n_kv_heads,
+                    plan_params.head_dim,
+                    plan_params.page_size,
+                    q_data_type=plan_params.q_dtype,
+                    kv_data_type=plan_params.kv_dtype,
+                    sm_scale=plan_params.sm_scale,
+                )
+
     def plan(
         self,
         qo_indptr: torch.Tensor,
@@ -96,7 +291,9 @@ def plan(
         flashinfer.BatchDecodeWithPagedKVCacheWrapper,
     ]:
         # plan decode helper function
-        def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
+        def _plan_decode(
+            wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
+        ):
             wrapper.plan(
                 kv_page_indptr,
                 kv_page_indices,
@@ -111,18 +308,23 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
             )

         # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
-        if cuda_graph_state.in_warm_up() and plan_params not in self.cached_decode_wrappers:
-            self.cached_decode_wrappers[plan_params] = self._init_decode_wrapper()
-            _plan_decode(self.cached_decode_wrappers[plan_params])
-
+        if (
+            cuda_graph_state.in_warm_up()
+            and plan_params not in self.cached_cuda_graph_decode_wrappers
+        ):
+            # During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
+            wrapper = self._init_decode_wrapper(
+                use_cuda_graph=True,
+                indptr=kv_page_indptr,
+                indices=kv_page_indices,
+                last_page_len=kv_last_page_len,
+            )
+            self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper
+            _plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
         # check if we are in cuda graph capture and just return the pre-cached decode wrapper
         if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
             assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
-            wrapper = self.cached_decode_wrappers[plan_params]
-            # copy the metadata to the wrapper to ensure it is up-to-date for graph replay!
-            wrapper._paged_kv_indptr_buf.copy_(kv_page_indptr)
-            wrapper._paged_kv_indices_buf.copy_(kv_page_indices)
-            wrapper._paged_kv_last_page_len_buf.copy_(kv_last_page_len)
+            wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
             return wrapper

         # check for re-planning
@@ -167,14 +369,13 @@ def prepare_flashinfer_metadata(
     https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.plan
     to understand the convention.
     """
-    # reset the planner
-    _GlobalFlashInferPlanner.reset()
-
     # retrieve host-side metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
     num_seq = num_prefill + num_decode
     num_tokens = num_prefill_tokens + num_decode

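+    # reset the planner so it re-plans with this batch's metadata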
+    _GlobalFlashInferPlanner.reset()
+
     qo_indptr = cu_seqlen[: num_seq + 1]

     # NOTE: in theory we could easily precompute batch_indices. And positions is just position_ids
@@ -398,6 +599,20 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:

         return {"workspace_buffer": _init_workspace}

+    @classmethod
+    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
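+        """Refresh flashinfer decode plans on the host before the forward pass."""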
+        batch_info = sequence_info._input_buffer.get_host_view("batch_info")
+        num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+        # Call plan for generate-only batches.
+        if num_prefill == 0:
+            _GlobalFlashInferPlanner.plan_generate_only(
+                num_decode,
+                sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
+                sequence_info._input_buffer.get_host_view("cache_loc"),
+                sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
+            )
+        return
+
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         # Sanity check: layout == "bsnd"