Commit 8c98841

separate prefill/decode in flashinfer
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent: 16e148d

File tree: 3 files changed, +250 −78 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 153 additions & 69 deletions
@@ -205,14 +205,16 @@ class _FlashInferPlanner:
     cached_cuda_graph_decode_wrappers: Dict[
         PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
     ]
-    plan_params: Optional[PlanParams]
+    plan_params_prefill: Optional[PlanParams]
+    plan_params_decode: Optional[PlanParams]
 
     def __init__(self):
         self.workspace_buffer = None
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.cached_cuda_graph_decode_wrappers = {}
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None
 
     def _init_decode_wrapper(
         self,
@@ -253,7 +255,8 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
         self.decode_wrapper = self._init_decode_wrapper()
 
     def reset(self) -> None:
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None
 
     def plan_generate_only(
         self,
@@ -279,9 +282,42 @@ def plan_generate_only(
             sm_scale=plan_params.sm_scale,
         )
 
-    def plan(
+    def plan_prefill(
+        self,
+        qo_indptr_host: torch.Tensor,
+        kv_page_indptr_host: torch.Tensor,
+        kv_page_indices: torch.Tensor,
+        kv_last_page_len_host: torch.Tensor,
+        kv_lens_arr_host: torch.Tensor,
+        seq_len_host: torch.Tensor,
+        plan_params: PlanParams,
+    ) -> None:
+        # check for re-planning
+        if plan_params != self.plan_params_prefill:
+            # plan prefill
+            self.prefill_wrapper.plan(
+                qo_indptr_host,
+                kv_page_indptr_host,
+                kv_page_indices,
+                kv_last_page_len_host,
+                plan_params.n_heads,  # Q heads
+                plan_params.n_kv_heads,  # KV heads
+                plan_params.head_dim,
+                plan_params.page_size,
+                causal=plan_params.causal,
+                q_data_type=plan_params.q_dtype,
+                kv_data_type=plan_params.kv_dtype,
+                sm_scale=plan_params.sm_scale,
+                # max_token_per_sequence=max(seq_len_host).item(),
+                seq_lens=kv_lens_arr_host,
+            )
+            self.plan_params_prefill = plan_params
+
+        # return prefill wrapper
+        return self.prefill_wrapper
+
+    def plan_decode(
         self,
-        qo_indptr: torch.Tensor,
         kv_page_indptr: torch.Tensor,
         kv_page_indices: torch.Tensor,
         kv_last_page_len: torch.Tensor,
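
For reference, `qo_indptr_host` and `kv_page_indptr_host` follow the usual FlashInfer "indptr" convention: exclusive prefix sums over per-sequence query-token counts and KV-page counts. A small illustrative sketch, not part of the commit; the sequence lengths and page size below are made-up example values:

import torch

# Hypothetical example values: three prefill sequences, no prior KV cache.
seq_len_host = torch.tensor([5, 3, 8])              # new tokens per sequence
seq_len_with_cache_host = torch.tensor([5, 3, 8])   # total KV length per sequence
page_size = 4

# qo_indptr: exclusive prefix sum over query token counts -> [0, 5, 8, 16]
qo_indptr_host = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_len_host, 0).to(torch.int32)]
)

# kv_page_indptr: prefix sum over the number of KV-cache pages each sequence occupies
num_pages = (seq_len_with_cache_host + page_size - 1) // page_size  # ceil division
kv_page_indptr_host = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(num_pages, 0).to(torch.int32)]
)

# last page length: how many slots of each sequence's final page are in use
kv_last_page_len_host = seq_len_with_cache_host - (num_pages - 1) * page_size

print(qo_indptr_host.tolist())         # [0, 5, 8, 16]
print(kv_page_indptr_host.tolist())    # [0, 2, 3, 5]
print(kv_last_page_len_host.tolist())  # [1, 3, 4]
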
@@ -328,29 +364,12 @@ def _plan_decode(
             return wrapper
 
         # check for re-planning
-        if plan_params != self.plan_params:
-            if plan_params.is_generate:
-                _plan_decode(self.decode_wrapper)
-            else:
-                # plan prefill
-                self.prefill_wrapper.plan(
-                    qo_indptr,
-                    kv_page_indptr,
-                    kv_page_indices,
-                    kv_last_page_len,
-                    plan_params.n_heads,  # Q heads
-                    plan_params.n_kv_heads,  # KV heads
-                    plan_params.head_dim,
-                    plan_params.page_size,
-                    causal=plan_params.causal,
-                    q_data_type=plan_params.q_dtype,
-                    kv_data_type=plan_params.kv_dtype,
-                    sm_scale=plan_params.sm_scale,
-                )
-            self.plan_params = plan_params
+        if plan_params != self.plan_params_decode:
+            _plan_decode(self.decode_wrapper)
+            self.plan_params_decode = plan_params
 
-        # return desired wrapper
-        return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
+        # return decode wrapper
+        return self.decode_wrapper
 
 
 _GlobalFlashInferPlanner = _FlashInferPlanner()
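
With this change the planner caches prefill and decode plan parameters independently, so a workload that alternates between the two paths no longer forces a re-plan on every call. A minimal, standalone sketch of that caching pattern, using a hypothetical `PlanKey`/`TwoPathPlanner` in place of the real `PlanParams`/`_FlashInferPlanner`:

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PlanKey:
    """Hypothetical stand-in for PlanParams: comparable, so it works as a cache key."""
    n_heads: int
    n_kv_heads: int
    head_dim: int
    num_seq: int
    page_size: int


class TwoPathPlanner:
    def __init__(self) -> None:
        self.plan_key_prefill: Optional[PlanKey] = None
        self.plan_key_decode: Optional[PlanKey] = None
        self.plan_calls = 0  # counts how often an (expensive) plan actually runs

    def _expensive_plan(self, kind: str, key: PlanKey) -> None:
        self.plan_calls += 1
        print(f"planning {kind} for {key.num_seq} sequences")

    def plan_prefill(self, key: PlanKey) -> None:
        if key != self.plan_key_prefill:  # re-plan only when the prefill shape changes
            self._expensive_plan("prefill", key)
            self.plan_key_prefill = key

    def plan_decode(self, key: PlanKey) -> None:
        if key != self.plan_key_decode:  # decode keeps its own cached key
            self._expensive_plan("decode", key)
            self.plan_key_decode = key


planner = TwoPathPlanner()
prefill_key = PlanKey(32, 8, 128, num_seq=2, page_size=64)
decode_key = PlanKey(32, 8, 128, num_seq=4, page_size=64)
for _ in range(3):  # alternating batches re-plan only on the first iteration
    planner.plan_prefill(prefill_key)
    planner.plan_decode(decode_key)
print("plan calls:", planner.plan_calls)  # 2, not 6
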
@@ -412,10 +431,14 @@ def flashinfer_mha_with_cache(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -441,30 +464,11 @@ def flashinfer_mha_with_cache(
     # convert to flashinfer-style metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
-
-    qo_indptr = cu_seqlen[: num_seq + 1]
-    paged_kv_indptr = cu_num_pages[: num_seq + 1]
-
-    # NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
-    # truncated and will point to the correct sub range of cache_loc.
-    paged_kv_indices = cache_loc
-    paged_kv_last_page_len = last_page_len[:num_seq]
+    num_total_tokens = num_prefill_tokens + num_decode
 
     n_heads = q.shape[1]
     n_kv_heads = k.shape[1]
 
-    pp = PlanParams(
-        n_heads=n_heads,
-        n_kv_heads=n_kv_heads,
-        head_dim=head_dim,
-        num_seq=len(qo_indptr) - 1,
-        is_generate=(s == 1),
-        page_size=k_cache.shape[1],
-        q_dtype=q.dtype,
-        kv_dtype=k_cache.dtype,
-        sm_scale=scale,
-    )
-
     # Assuming k_scale = v_scale = 1.0
     k_scale, v_scale = 1.0, 1.0
     # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
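
The op assumes a flattened, token-major layout of `[prefill tokens | decode tokens]`: the first `num_prefill_tokens` rows of `q` belong to the `num_prefill` context-phase sequences, and each of the `num_decode` generate-phase sequences contributes exactly one token, hence `num_total_tokens = num_prefill_tokens + num_decode`. A quick worked example of the slicing, with illustrative numbers only:

import torch

# Illustrative batch: 2 prefill sequences (5 + 3 tokens) and 4 decode sequences (1 token each).
num_prefill, num_prefill_tokens, num_decode = 2, 8, 4
num_seq = num_prefill + num_decode                  # 6
num_total_tokens = num_prefill_tokens + num_decode  # 12

q = torch.randn(num_total_tokens, 32, 128)          # [tokens, n_heads, head_dim]
q_prefill = q[:num_prefill_tokens]                  # rows 0..7  -> prefill wrapper
q_decode = q[num_prefill_tokens:num_total_tokens]   # rows 8..11 -> decode wrapper

assert q_prefill.shape[0] == num_prefill_tokens and q_decode.shape[0] == num_decode
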
@@ -473,28 +477,94 @@ def flashinfer_mha_with_cache(
         v = v.to(torch.float8_e4m3fn)
 
     flashinfer.page.append_paged_kv_cache(
-        k,
-        v,
-        flashinfer_batch_indices,
-        flashinfer_positions,
-        (k_cache, v_cache),
-        paged_kv_indices,
-        paged_kv_indptr,
-        paged_kv_last_page_len,
+        append_key=k,
+        append_value=v,
+        batch_indices=flashinfer_batch_indices,
+        positions=flashinfer_positions,
+        paged_kv_cache=(k_cache, v_cache),
+        kv_indices=cache_loc,
+        kv_indptr=cu_num_pages[: num_seq + 1],
+        kv_last_page_len=last_page_len[:num_seq],
     )
 
-    # run the flashinfer planner and obtain the correct wrapper
-    wrapper = _GlobalFlashInferPlanner.plan(
-        qo_indptr,
-        paged_kv_indptr,
-        paged_kv_indices,
-        paged_kv_last_page_len,
-        pp,
-    )
+    # check if we need to re-combine outputs
+    if num_prefill > 0 and num_decode > 0:
+        y = torch.empty_like(q)
+    else:
+        y = None
+
+    # now run split prefill, decode
+    if num_prefill > 0:
+        q_prefill = q[:num_prefill_tokens]
+
+        pp_prefill = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_prefill,
+            is_generate=False,
+            page_size=k_cache.shape[1],
+            q_dtype=q_prefill.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )
 
-    y = wrapper.run(
-        q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
-    )
+        wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
+            qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
+            kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len_host=last_page_len_host[:num_prefill],
+            kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
+            seq_len_host=seq_len_host[:num_prefill],
+            plan_params=pp_prefill,
+        )
+
+        y_prefill = wrapper_prefill.run(
+            q_prefill,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[:num_prefill_tokens] = y_prefill
+        else:
+            y = y_prefill
+
+    if num_decode > 0:
+        q_decode = q[num_prefill_tokens:num_total_tokens]
+
+        pp_decode = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_decode,
+            is_generate=True,
+            page_size=k_cache.shape[1],
+            q_dtype=q_decode.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )
+
+        # run the flashinfer planner and obtain the correct wrapper
+        wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
+            kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len=last_page_len[num_prefill:num_seq],
+            plan_params=pp_decode,
+        )
+
+        y_decode = wrapper_decode.run(
+            q_decode,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[num_prefill_tokens:num_total_tokens] = y_decode
+        else:
+            y = y_decode
 
     return y.view(q_shape_og)  # [b,s,n*h_d] or [b,s, n, h_d]
 
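
Note how outputs are recombined: a combined buffer is allocated only for mixed batches, while a pure-prefill or pure-decode batch returns its wrapper output directly. A minimal sketch of that pattern, with an identity op standing in for the FlashInfer wrappers; names and shapes are illustrative:

import torch


def run_split_attention(q: torch.Tensor, num_prefill_tokens: int, num_decode: int) -> torch.Tensor:
    """Toy stand-in: 'attention' is an identity op, to show only the recombination logic."""
    num_total_tokens = num_prefill_tokens + num_decode

    # allocate a combined output only when both segments are present
    y = torch.empty_like(q) if num_prefill_tokens > 0 and num_decode > 0 else None

    if num_prefill_tokens > 0:
        y_prefill = q[:num_prefill_tokens] * 1.0  # placeholder for prefill_wrapper.run(...)
        if y is not None:
            y[:num_prefill_tokens] = y_prefill
        else:
            y = y_prefill

    if num_decode > 0:
        y_decode = q[num_prefill_tokens:num_total_tokens] * 1.0  # placeholder for decode_wrapper.run(...)
        if y is not None:
            y[num_prefill_tokens:num_total_tokens] = y_decode
        else:
            y = y_decode

    return y


q = torch.randn(12, 32, 128)
out = run_split_attention(q, num_prefill_tokens=8, num_decode=4)
assert torch.equal(out, q)
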
@@ -507,10 +577,14 @@ def flashinfer_mha_with_cache_fake(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -559,7 +633,17 @@ def get_cached_attention_op(cls) -> MHACallable:
 
     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return [
+            "batch_info_host",
+            "cu_seqlen_host",
+            "cu_num_pages",
+            "cu_num_pages_host",
+            "cache_loc",
+            "last_page_len",
+            "last_page_len_host",
+            "seq_len_with_cache_host",
+            "seq_len_host",
+        ]
 
     @classmethod
     def get_prepare_extra_metadata_info(
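
The standard metadata now pairs several device tensors with `*_host` copies: in this diff the `plan_prefill` call consumes the host-side variants while `append_paged_kv_cache` and the wrappers' `run()` consume the device tensors, presumably so planning can read lengths on the CPU without a device sync. A hedged sketch of preparing such host mirrors up front (hypothetical helper, not the interface added by this commit):

from typing import Dict

import torch


def build_metadata(cu_seqlen: torch.Tensor, cu_num_pages: torch.Tensor,
                   last_page_len: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: keep a CPU mirror of each device-side metadata tensor so
    planning never has to pull values from the GPU inside the hot path."""
    return {
        "cu_seqlen_host": cu_seqlen.cpu(),
        "cu_num_pages": cu_num_pages,
        "cu_num_pages_host": cu_num_pages.cpu(),
        "last_page_len": last_page_len,
        "last_page_len_host": last_page_len.cpu(),
    }


device = "cuda" if torch.cuda.is_available() else "cpu"
md = build_metadata(
    cu_seqlen=torch.tensor([0, 5, 8], device=device),
    cu_num_pages=torch.tensor([0, 2, 3], device=device),
    last_page_len=torch.tensor([1, 3], device=device),
)
print({k: v.device.type for k, v in md.items()})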

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ def _call_func():
 
         # check if we have a dummy request to use
         if self.padding_dummy_request is None:
-            ad_logger.error("No CUDA graph padding possible due to missing dummy request.")
+            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
             return _call_func()
 
         # pad the scheduled requests with the dummy request
