
Commit 746b71f

separate prefill/decode in flashinfer
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent f8ff684 commit 746b71f

File tree: 3 files changed (+251 −78 lines)


tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 154 additions & 69 deletions
@@ -54,13 +54,16 @@ class _FlashInferPlanner:
     decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
     cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
     plan_params: Optional[PlanParams]
+    plan_params_prefill: Optional[PlanParams]
+    plan_params_decode: Optional[PlanParams]

     def __init__(self):
         self.workspace_buffer = None
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.cached_decode_wrappers = {}
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

     def _init_decode_wrapper(self):
         assert self.workspace_buffer is not None
@@ -82,11 +85,45 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
         self.decode_wrapper = self._init_decode_wrapper()

     def reset(self) -> None:
-        self.plan_params = None
+        self.plan_params_prefill = None
+        self.plan_params_decode = None

-    def plan(
+    def plan_prefill(
+        self,
+        qo_indptr_host: torch.Tensor,
+        kv_page_indptr_host: torch.Tensor,
+        kv_page_indices: torch.Tensor,
+        kv_last_page_len_host: torch.Tensor,
+        kv_lens_arr_host: torch.Tensor,
+        seq_len_host: torch.Tensor,
+        plan_params: PlanParams,
+    ) -> None:
+        # check for re-planning
+        if plan_params != self.plan_params_prefill:
+            # plan prefill
+            self.prefill_wrapper.plan(
+                qo_indptr_host,
+                kv_page_indptr_host,
+                kv_page_indices,
+                kv_last_page_len_host,
+                plan_params.n_heads,  # Q heads
+                plan_params.n_kv_heads,  # KV heads
+                plan_params.head_dim,
+                plan_params.page_size,
+                causal=plan_params.causal,
+                q_data_type=plan_params.q_dtype,
+                kv_data_type=plan_params.kv_dtype,
+                sm_scale=plan_params.sm_scale,
+                # max_token_per_sequence=max(seq_len_host).item(),
+                seq_lens=kv_lens_arr_host,
+            )
+            self.plan_params_prefill = plan_params
+
+        # return prefill wrapper
+        return self.prefill_wrapper
+
+    def plan_decode(
         self,
-        qo_indptr: torch.Tensor,
         kv_page_indptr: torch.Tensor,
         kv_page_indices: torch.Tensor,
         kv_last_page_len: torch.Tensor,
@@ -126,29 +163,12 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
            return wrapper

        # check for re-planning
-        if plan_params != self.plan_params:
-            if plan_params.is_generate:
-                _plan_decode(self.decode_wrapper)
-            else:
-                # plan prefill
-                self.prefill_wrapper.plan(
-                    qo_indptr,
-                    kv_page_indptr,
-                    kv_page_indices,
-                    kv_last_page_len,
-                    plan_params.n_heads,  # Q heads
-                    plan_params.n_kv_heads,  # KV heads
-                    plan_params.head_dim,
-                    plan_params.page_size,
-                    causal=plan_params.causal,
-                    q_data_type=plan_params.q_dtype,
-                    kv_data_type=plan_params.kv_dtype,
-                    sm_scale=plan_params.sm_scale,
-                )
-            self.plan_params = plan_params
-
-        # return desired wrapper
-        return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
+        if plan_params != self.plan_params_decode:
+            _plan_decode(self.decode_wrapper)
+            self.plan_params_decode = plan_params
+
+        # return decode wrapper
+        return self.decode_wrapper


 _GlobalFlashInferPlanner = _FlashInferPlanner()
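
Note on the planner change above: caching `plan_params_prefill` and `plan_params_decode` separately means that alternating prefill and decode calls no longer invalidate each other's cached plan. A minimal standalone sketch of that caching pattern (the `SplitPlanner` class, its fields, and the replan counter are hypothetical stand-ins, not the flashinfer API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PlanParams:
    # hypothetical subset of the real PlanParams fields
    n_heads: int
    num_seq: int
    is_generate: bool


class SplitPlanner:
    """Caches the last plan separately for prefill and decode.

    With a single cached plan_params, alternating prefill/decode calls would
    invalidate each other and force a re-plan on every call; with two slots,
    each side only re-plans when its own parameters change.
    """

    def __init__(self) -> None:
        self.plan_params_prefill: Optional[PlanParams] = None
        self.plan_params_decode: Optional[PlanParams] = None
        self.num_replans = 0  # for illustration only

    def plan_prefill(self, plan_params: PlanParams) -> None:
        if plan_params != self.plan_params_prefill:
            self.num_replans += 1  # real code would call prefill_wrapper.plan(...)
            self.plan_params_prefill = plan_params

    def plan_decode(self, plan_params: PlanParams) -> None:
        if plan_params != self.plan_params_decode:
            self.num_replans += 1  # real code would call decode_wrapper.plan(...)
            self.plan_params_decode = plan_params


planner = SplitPlanner()
pp_prefill = PlanParams(n_heads=8, num_seq=2, is_generate=False)
pp_decode = PlanParams(n_heads=8, num_seq=4, is_generate=True)
for _ in range(3):  # alternating calls re-plan only twice in total
    planner.plan_prefill(pp_prefill)
    planner.plan_decode(pp_decode)
assert planner.num_replans == 2
```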
@@ -211,10 +231,14 @@ def flashinfer_mha_with_cache(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -240,30 +264,11 @@ def flashinfer_mha_with_cache(
     # convert to flashinfer-style metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
-
-    qo_indptr = cu_seqlen[: num_seq + 1]
-    paged_kv_indptr = cu_num_pages[: num_seq + 1]
-
-    # NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
-    # truncated and will point to the correct sub range of cache_loc.
-    paged_kv_indices = cache_loc
-    paged_kv_last_page_len = last_page_len[:num_seq]
+    num_total_tokens = num_prefill_tokens + num_decode

     n_heads = q.shape[1]
     n_kv_heads = k.shape[1]

-    pp = PlanParams(
-        n_heads=n_heads,
-        n_kv_heads=n_kv_heads,
-        head_dim=head_dim,
-        num_seq=len(qo_indptr) - 1,
-        is_generate=(s == 1),
-        page_size=k_cache.shape[1],
-        q_dtype=q.dtype,
-        kv_dtype=k_cache.dtype,
-        sm_scale=scale,
-    )
-
     # Assuming k_scale = v_scale = 1.0
     k_scale, v_scale = 1.0, 1.0
     # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
@@ -272,28 +277,94 @@ def flashinfer_mha_with_cache(
         v = v.to(torch.float8_e4m3fn)

     flashinfer.page.append_paged_kv_cache(
-        k,
-        v,
-        flashinfer_batch_indices,
-        flashinfer_positions,
-        (k_cache, v_cache),
-        paged_kv_indices,
-        paged_kv_indptr,
-        paged_kv_last_page_len,
+        append_key=k,
+        append_value=v,
+        batch_indices=flashinfer_batch_indices,
+        positions=flashinfer_positions,
+        paged_kv_cache=(k_cache, v_cache),
+        kv_indices=cache_loc,
+        kv_indptr=cu_num_pages[: num_seq + 1],
+        kv_last_page_len=last_page_len[:num_seq],
     )

-    # run the flashinfer planner and obtain the correct wrapper
-    wrapper = _GlobalFlashInferPlanner.plan(
-        qo_indptr,
-        paged_kv_indptr,
-        paged_kv_indices,
-        paged_kv_last_page_len,
-        pp,
-    )
+    # check if we need to re-combine outputs
+    if num_prefill > 0 and num_decode > 0:
+        y = torch.empty_like(q)
+    else:
+        y = None
+
+    # now run split prefill, decode
+    if num_prefill > 0:
+        q_prefill = q[:num_prefill_tokens]
+
+        pp_prefill = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_prefill,
+            is_generate=False,
+            page_size=k_cache.shape[1],
+            q_dtype=q_prefill.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )

-    y = wrapper.run(
-        q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
-    )
+        wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
+            qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
+            kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len_host=last_page_len_host[:num_prefill],
+            kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
+            seq_len_host=seq_len_host[:num_prefill],
+            plan_params=pp_prefill,
+        )
+
+        y_prefill = wrapper_prefill.run(
+            q_prefill,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[:num_prefill_tokens] = y_prefill
+        else:
+            y = y_prefill
+
+    if num_decode > 0:
+        q_decode = q[num_prefill_tokens:num_total_tokens]
+
+        pp_decode = PlanParams(
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=head_dim,
+            num_seq=num_decode,
+            is_generate=True,
+            page_size=k_cache.shape[1],
+            q_dtype=q_decode.dtype,
+            kv_dtype=k_cache.dtype,
+            sm_scale=scale,
+        )
+
+        # run the flashinfer planner and obtain the correct wrapper
+        wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
+            kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
+            kv_page_indices=cache_loc,
+            kv_last_page_len=last_page_len[num_prefill:num_seq],
+            plan_params=pp_decode,
+        )
+
+        y_decode = wrapper_decode.run(
+            q_decode,
+            (k_cache, v_cache),
+            k_scale=k_scale,
+            v_scale=v_scale,
+            enable_pdl=get_env_enable_pdl(),
+        )
+        if y is not None:
+            y[num_prefill_tokens:num_total_tokens] = y_decode
+        else:
+            y = y_decode

     return y.view(q_shape_og)  # [b,s,n*h_d] or [b,s, n, h_d]

@@ -306,10 +377,14 @@ def flashinfer_mha_with_cache_fake(
     v: torch.Tensor,
     # STANDARD METADATA
     batch_info_host: torch.Tensor,
-    cu_seqlen: torch.Tensor,
+    cu_seqlen_host: torch.Tensor,
     cu_num_pages: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
     cache_loc: torch.Tensor,
     last_page_len: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+    seq_len_with_cache_host: torch.Tensor,
+    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -358,7 +433,17 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return [
+            "batch_info_host",
+            "cu_seqlen_host",
+            "cu_num_pages",
+            "cu_num_pages_host",
+            "cache_loc",
+            "last_page_len",
+            "last_page_len_host",
+            "seq_len_with_cache_host",
+            "seq_len_host",
+        ]

     @classmethod
     def get_prepare_extra_metadata_info(
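
Note on the rewritten `flashinfer_mha_with_cache` above: the op now slices the flattened token batch into a prefill segment (`q[:num_prefill_tokens]`) and a decode segment (`q[num_prefill_tokens:num_total_tokens]`), runs each through its own planned wrapper, and scatters both results into a shared output when both segments are present. A simplified sketch of that split-and-recombine flow, assuming hypothetical `run_prefill`/`run_decode` placeholders in place of the real wrapper calls:

```python
import torch


def run_prefill(q: torch.Tensor) -> torch.Tensor:
    # stand-in for wrapper_prefill.run(...)
    return q * 2.0


def run_decode(q: torch.Tensor) -> torch.Tensor:
    # stand-in for wrapper_decode.run(...)
    return q * 3.0


def mha_split(q: torch.Tensor, num_prefill_tokens: int, num_decode: int) -> torch.Tensor:
    num_total_tokens = num_prefill_tokens + num_decode

    # allocate a combined output only when both segments are present
    y = torch.empty_like(q) if num_prefill_tokens > 0 and num_decode > 0 else None

    if num_prefill_tokens > 0:
        y_prefill = run_prefill(q[:num_prefill_tokens])
        if y is None:
            y = y_prefill
        else:
            y[:num_prefill_tokens] = y_prefill

    if num_decode > 0:
        y_decode = run_decode(q[num_prefill_tokens:num_total_tokens])
        if y is None:
            y = y_decode
        else:
            y[num_prefill_tokens:num_total_tokens] = y_decode

    return y


q = torch.arange(6.0).reshape(6, 1)  # 4 prefill tokens + 2 decode tokens
out = mha_split(q, num_prefill_tokens=4, num_decode=2)
assert torch.equal(out[:4], q[:4] * 2.0) and torch.equal(out[4:], q[4:] * 3.0)
```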

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ def _call_func():

         # check if we have a dummy request to use
         if self.padding_dummy_request is None:
-            ad_logger.error("No CUDA graph padding possible due to missing dummy request.")
+            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
             return _call_func()

         # pad the scheduled requests with the dummy request
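
Note on the `ad_executor.py` change above: hitting this branch is an expected fallback (the call simply proceeds without CUDA graph padding), so it is logged at info level rather than as an error. A rough sketch of the surrounding control flow, with hypothetical names and the padding step elided:

```python
import logging

ad_logger = logging.getLogger("auto_deploy")  # stand-in for the real ad_logger


def maybe_pad_and_call(call_func, padding_dummy_request):
    """Hypothetical shape of the padding fallback around a CUDA-graph call."""
    if padding_dummy_request is None:
        # expected fallback path, hence info rather than error
        ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
        return call_func()
    # ... pad the scheduled requests with the dummy request, then call ...
    return call_func()
```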
