Commit e2891a6

[#10052][feat] AutoDeploy enable cudagraphs for flashinfer BatchDecode (#10193)
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
1 parent ddac4d7 commit e2891a6

File tree: 4 files changed, +275 -20 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 34 additions & 1 deletion
@@ -10,7 +10,19 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
+from typing import (
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -512,6 +524,9 @@ def __init__(
         self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
         ############################################################################################
 
+        # HOST PREPARE FOR ATTENTION FORWARD #######################################################
+        self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
+
         # call reset once to set a consistent initial state
         self.reset()
 
@@ -1089,6 +1104,15 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(int(self.is_generate))
         return list(torch.split(t_squeezed, self.seq_len))
 
+    def register_host_prepare_for_attention_forward(
+        self, host_function: Callable[["SequenceInfo"], None]
+    ):
+        self._host_prepare_functions.add(host_function)
+
+    def run_host_prepare_for_attention_forward(self) -> None:
+        for host_function in self._host_prepare_functions:
+            host_function(self)
+
 
 class MHACallable(Protocol):
     def __call__(
@@ -1266,6 +1290,15 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         """
         return []
 
+    @classmethod
+    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
+        """Perform host-side preparation for the forward pass for the attention op.
+
+        This method is responsible for preparing the attention op for the forward pass.
+        This function is not expected to be graph capturable or compatible with cuda graphs.
+        """
+        return
+
 
 class AttentionRegistry:
     """A simple registry to look up different attention implementations."""

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 234 additions & 19 deletions
@@ -25,6 +25,156 @@
 )
 
 
+# TODO: remove this when flashinfer version is updated to >0.5
+def fast_decode_plan(
+    wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+    non_blocking: bool = True,
+    fixed_split_size: Optional[int] = None,
+    disable_split_kv: bool = False,
+    global_override_indptr_cpu: Optional[torch.Tensor] = None,
+) -> None:
+    """
+    Copied from flashinfer.decode.fast_decode_plan in flashinfer version >0.5.
+    Does not exist in flashinfer version 0.3.1, hence copied here.
+    """
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
+    if wrapper.use_tensor_cores:
+        qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
+    # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
+    if fixed_split_size is None:
+        fixed_split_size = -1
+
+    if wrapper.is_cuda_graph_enabled:
+        if batch_size != wrapper._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, wrapper._fixed_batch_size
+                )
+            )
+        if len(indices) > len(wrapper._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        wrapper._paged_kv_indptr_buf = indptr
+        wrapper._paged_kv_indices_buf = indices
+        wrapper._paged_kv_last_page_len_buf = last_page_len
+        if wrapper.use_tensor_cores:
+            wrapper._qo_indptr_buf = qo_indptr_host.to(wrapper.device, non_blocking=non_blocking)
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type),
+        device=wrapper.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type),
+        device=wrapper.device,
+    )
+
+    indptr_host = (
+        global_override_indptr_cpu if global_override_indptr_cpu is not None else indptr.cpu()
+    )
+
+    with torch.cuda.device(wrapper.device):
+        if wrapper.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            if page_size == 1:
+                # When page size is 1, last_page_len is always 1.
+                # Directly construct the host tensor rather than executing a device-to-host copy.
+                last_page_len_host = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
+            else:
+                last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = flashinfer.get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                wrapper._plan_info = wrapper._cached_module.plan(
+                    wrapper._float_workspace_buffer,
+                    wrapper._int_workspace_buffer,
+                    wrapper._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    wrapper.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}") from e
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                wrapper._plan_info = wrapper._cached_module.plan(
+                    wrapper._float_workspace_buffer,
+                    wrapper._int_workspace_buffer,
+                    wrapper._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    wrapper.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}") from e
+
+    wrapper._pos_encoding_mode = pos_encoding_mode
+    wrapper._window_left = window_left
+    wrapper._logits_soft_cap = logits_soft_cap
+    wrapper._sm_scale = sm_scale
+    wrapper._rope_scale = rope_scale
+    wrapper._rope_theta = rope_theta
+
+
 @dataclass
 class PlanParams:
     """Parameters that affect the flashinfer execution plan."""
@@ -52,21 +202,42 @@ class _FlashInferPlanner:
     workspace_buffer: Optional[torch.Tensor]
     prefill_wrapper: Optional[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
     decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
-    cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
+    cached_cuda_graph_decode_wrappers: Dict[
+        PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
+    ]
     plan_params: Optional[PlanParams]
 
     def __init__(self):
        self.workspace_buffer = None
        self.prefill_wrapper = None
        self.decode_wrapper = None
-        self.cached_decode_wrappers = {}
+        self.cached_cuda_graph_decode_wrappers = {}
        self.plan_params = None
 
-    def _init_decode_wrapper(self):
+    def _init_decode_wrapper(
+        self,
+        use_cuda_graph: bool = False,
+        indptr: Optional[torch.Tensor] = None,
+        indices: Optional[torch.Tensor] = None,
+        last_page_len: Optional[torch.Tensor] = None,
+    ):
         assert self.workspace_buffer is not None
-        return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-            self.workspace_buffer, "NHD", use_tensor_cores=True
-        )
+        if use_cuda_graph:
+            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                paged_kv_indptr_buffer=indptr,
+                paged_kv_indices_buffer=indices,
+                paged_kv_last_page_len_buffer=last_page_len,
+                use_tensor_cores=True,
+            )
+        else:
+            return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_tensor_cores=True,
+            )
 
     def init_workspace(self, workspace_buffer: torch.Tensor):
         self.__init__()  # reset all state
@@ -84,6 +255,30 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
     def reset(self) -> None:
         self.plan_params = None
 
+    def plan_generate_only(
+        self,
+        num_seq: int,
+        cu_num_pages: torch.Tensor,
+        cache_loc: torch.Tensor,
+        last_page_len: torch.Tensor,
+    ):
+        for plan_params in self.cached_cuda_graph_decode_wrappers:
+            if plan_params.num_seq == num_seq:
+                wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
+                fast_decode_plan(
+                    wrapper,
+                    cu_num_pages,
+                    cache_loc,
+                    last_page_len,
+                    plan_params.n_heads,
+                    plan_params.n_kv_heads,
+                    plan_params.head_dim,
+                    plan_params.page_size,
+                    q_data_type=plan_params.q_dtype,
+                    kv_data_type=plan_params.kv_dtype,
+                    sm_scale=plan_params.sm_scale,
+                )
+
     def plan(
         self,
         qo_indptr: torch.Tensor,
@@ -96,7 +291,9 @@ def plan(
         flashinfer.BatchDecodeWithPagedKVCacheWrapper,
     ]:
         # plan decode helper function
-        def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
+        def _plan_decode(
+            wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
+        ):
             wrapper.plan(
                 kv_page_indptr,
                 kv_page_indices,
@@ -111,18 +308,23 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
             )
 
         # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
-        if cuda_graph_state.in_warm_up() and plan_params not in self.cached_decode_wrappers:
-            self.cached_decode_wrappers[plan_params] = self._init_decode_wrapper()
-            _plan_decode(self.cached_decode_wrappers[plan_params])
-
+        if (
+            cuda_graph_state.in_warm_up()
+            and plan_params not in self.cached_cuda_graph_decode_wrappers
+        ):
+            # During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
+            wrapper = self._init_decode_wrapper(
+                use_cuda_graph=True,
+                indptr=kv_page_indptr,
+                indices=kv_page_indices,
+                last_page_len=kv_last_page_len,
+            )
+            self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper
+            _plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
         # check if we are in cuda graph capture and just return the pre-cached decode wrapper
         if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
             assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
-            wrapper = self.cached_decode_wrappers[plan_params]
-            # copy the metadata to the wrapper to ensure it is up-to-date for graph replay!
-            wrapper._paged_kv_indptr_buf.copy_(kv_page_indptr)
-            wrapper._paged_kv_indices_buf.copy_(kv_page_indices)
-            wrapper._paged_kv_last_page_len_buf.copy_(kv_last_page_len)
+            wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
             return wrapper
 
         # check for re-planning
@@ -167,14 +369,13 @@ def prepare_flashinfer_metadata(
     https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.plan
     to understand the convention.
     """
-    # reset the planner
-    _GlobalFlashInferPlanner.reset()
-
     # retrieve host-side metadata
     num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
     num_seq = num_prefill + num_decode
     num_tokens = num_prefill_tokens + num_decode
 
+    _GlobalFlashInferPlanner.reset()
+
     qo_indptr = cu_seqlen[: num_seq + 1]
 
     # NOTE: in theory we could easily precompute batch_indices. And positions is just position_ids
@@ -398,6 +599,20 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:
 
         return {"workspace_buffer": _init_workspace}
 
+    @classmethod
+    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
+        batch_info = sequence_info._input_buffer.get_host_view("batch_info")
+        num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+        # Call plan for generate-only batches.
+        if num_prefill == 0:
+            _GlobalFlashInferPlanner.plan_generate_only(
+                num_decode,
+                sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
+                sequence_info._input_buffer.get_host_view("cache_loc"),
+                sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
+            )
+        return
+
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         # Sanity check: layout == "bsnd"
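
Reviewer note: the planner now caches one graph-enabled decode wrapper per PlanParams and, for generate-only batches, re-plans whichever cached wrapper was captured for the current decode batch size instead of copying metadata into wrapper buffers. The sketch below is a rough, flashinfer-free illustration of that cache-and-replan pattern; MiniPlanParams, MiniWrapper, and MiniPlanner are hypothetical stand-ins, not the real classes.

from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)  # frozen -> hashable, so instances can key the cache dict
class MiniPlanParams:
    num_seq: int
    n_heads: int
    n_kv_heads: int
    head_dim: int
    page_size: int


class MiniWrapper:
    """Stand-in for a CUDA-graph-enabled BatchDecode wrapper with a fixed batch size."""

    def __init__(self, fixed_batch_size: int) -> None:
        self._fixed_batch_size = fixed_batch_size

    def replan(self, num_pages_per_seq: List[int]) -> None:
        # The real code calls fast_decode_plan() here: it refreshes the host-side plan
        # while the captured graph keeps reading the same device metadata buffers.
        assert len(num_pages_per_seq) == self._fixed_batch_size


class MiniPlanner:
    def __init__(self) -> None:
        self.cached_cuda_graph_decode_wrappers: Dict[MiniPlanParams, MiniWrapper] = {}

    def warm_up(self, plan_params: MiniPlanParams) -> None:
        # CUDA graph warm-up: create and cache one wrapper per distinct plan_params.
        if plan_params not in self.cached_cuda_graph_decode_wrappers:
            self.cached_cuda_graph_decode_wrappers[plan_params] = MiniWrapper(plan_params.num_seq)

    def plan_generate_only(self, num_seq: int, num_pages_per_seq: List[int]) -> None:
        # Per decode step: re-plan every cached wrapper captured for this batch size.
        for plan_params, wrapper in self.cached_cuda_graph_decode_wrappers.items():
            if plan_params.num_seq == num_seq:
                wrapper.replan(num_pages_per_seq)


planner = MiniPlanner()
params = MiniPlanParams(num_seq=8, n_heads=32, n_kv_heads=8, head_dim=128, page_size=64)
planner.warm_up(params)
planner.plan_generate_only(8, [3] * 8)  # refreshes the plan for the batch-size-8 graph

As I read the change, the captured graph keeps reading the same device metadata buffers (they are handed to the wrapper as its paged_kv_* buffers at warm-up), so only the host-side plan has to be refreshed before each replay, which is what plan_generate_only does through fast_decode_plan.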

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 2 additions & 0 deletions
@@ -668,6 +668,8 @@ def _build_input_ids(request) -> Tuple[List[int], List[int], bool]:
         if new_tokens is not None:
             self.cache_seq_interface.info.rescatter_input_ids(new_tokens.flatten())
 
+        self.cache_seq_interface.info.run_host_prepare_for_attention_forward()
+
         self.iter_states["num_ctx_requests"] = num_ctx_requests
         self.iter_states["num_ctx_tokens"] = num_ctx_tokens
         # TODO: handle extend requests and draft requests for specdec
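
Reviewer note: the ordering here appears deliberate: the host prepare runs after the new tokens have been rescattered into the input buffer and before the engine forward, so the flashinfer hook re-plans against the batch about to execute while staying outside any captured CUDA graph region, consistent with the "not graph capturable" contract documented on host_prepare_for_forward.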
