Commit b85376f

feedback and updates
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 8c98841 commit b85376f

3 files changed: +67 -57 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 19 additions & 30 deletions
@@ -10,19 +10,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import (
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -36,6 +24,10 @@
 Constant = Union[int, float, str, None]
 
 
+class PrepareMetadataHostCallable(Protocol):
+    def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
+
+
 class InputBuffer:
     """Manages contiguous memory buffers for efficient host-to-device transfers.
 
@@ -388,6 +380,9 @@ class SequenceInfo:
     - _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
       Mask scatter indices used by the overlap scheduler to scatter results back.
 
+    NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
+    the tensor "batch_info" is accessible as "batch_info_host" on the host.
+
     ################################################################################################
 
     Here are a couple of notes to emphasize this notation:
@@ -526,7 +521,7 @@ def __init__(
         ############################################################################################
 
         # HOST PREPARE FOR ATTENTION FORWARD #######################################################
-        self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
+        self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []
 
         # call reset once to set a consistent initial state
         self.reset()
@@ -1043,13 +1038,13 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         return list(torch.split(t_squeezed, self.seq_len))
 
     def register_host_prepare_for_attention_forward(
-        self, host_function: Callable[["SequenceInfo"], None]
+        self, host_function: PrepareMetadataHostCallable, args: List[str]
     ):
-        self._host_prepare_functions.add(host_function)
+        self._host_prepare_functions.append((host_function, args))
 
     def run_host_prepare_for_attention_forward(self) -> None:
-        for host_function in self._host_prepare_functions:
-            host_function(self)
+        for host_function, args in self._host_prepare_functions:
+            host_function(**{arg: self._get_arg(arg) for arg in args})
 
 
 class MHACallable(Protocol):
@@ -1061,14 +1056,7 @@ def __call__(
 
 class PrepareMetadataCallable(Protocol):
     def __call__(
-        self,
-        position_ids: torch.Tensor,
-        seq_len: torch.Tensor,
-        input_pos: torch.Tensor,
-        cache_loc: torch.Tensor,
-        pages_per_seq: torch.Tensor,
-        slot_idx: torch.Tensor,
-        page_size: int,
+        self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
     ) -> List[torch.Tensor]: ...
 
 
@@ -1229,13 +1217,14 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         return []
 
     @classmethod
-    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
-        """Perform host-side preparation for the forward pass for the attention op.
+    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
+        """Get function that performs host-side prep for the forward pass for the attention op.
 
         This method is responsible for preparing the attention op for the forward pass.
-        This function is not expected to be graph capturable or compatible with cuda graphs.
+        This function is not expected to be graph capturable or compatible with cuda graphs. It can
+        use any argument from the SequenceInfo interface as input argument to its function.
         """
-        return
+        return None
 
 
 class AttentionRegistry:
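
Note: the new registration pattern stores each host-prep function together with the SequenceInfo argument names it declares and later calls it with keyword arguments. Below is a minimal, self-contained sketch of that dispatch; the names register, run_all, my_host_prep, and the _args dict are illustrative stand-ins, not the actual SequenceInfo API.

from typing import Callable, Dict, List, Tuple

import torch

# Registry entries pair a host-prep callable with the argument names it declares.
_host_prepare_functions: List[Tuple[Callable[..., None], List[str]]] = []

# Stand-in for SequenceInfo's named host-side tensors (hypothetical values).
_args: Dict[str, torch.Tensor] = {
    "batch_info_host": torch.tensor([0, 0, 2]),
    "seq_len_host": torch.tensor([1, 1]),
}


def register(fn: Callable[..., None], args: List[str]) -> None:
    _host_prepare_functions.append((fn, args))


def run_all() -> None:
    # Resolve each declared name to its host tensor and call with keyword arguments,
    # mirroring run_host_prepare_for_attention_forward above.
    for fn, args in _host_prepare_functions:
        fn(**{arg: _args[arg] for arg in args})


def my_host_prep(batch_info_host: torch.Tensor, seq_len_host: torch.Tensor) -> None:
    num_prefill, _num_prefill_tokens, num_decode = batch_info_host.tolist()
    print(num_prefill, num_decode, seq_len_host.tolist())


register(my_host_prep, ["batch_info_host", "seq_len_host"])
run_all()  # prints: 0 2 [1, 1]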

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 26 additions & 22 deletions
@@ -21,6 +21,7 @@
     Constant,
     MHACallable,
     PrepareMetadataCallable,
+    PrepareMetadataHostCallable,
     SequenceInfo,
 )
 
@@ -183,7 +184,6 @@ class PlanParams:
     n_kv_heads: int
     head_dim: int
     num_seq: int
-    is_generate: bool
     page_size: int
     q_dtype: torch.dtype
     kv_dtype: torch.dtype
@@ -289,12 +289,17 @@ def plan_prefill(
         kv_page_indices: torch.Tensor,
         kv_last_page_len_host: torch.Tensor,
         kv_lens_arr_host: torch.Tensor,
-        seq_len_host: torch.Tensor,
         plan_params: PlanParams,
     ) -> None:
         # check for re-planning
         if plan_params != self.plan_params_prefill:
             # plan prefill
+            # NOTE (lucaslie): we use host versions here. the plan actually needs both (host+device)
+            # version. Unfortunately, there is no good way to access the plan API and provide both
+            # although we have both available. I have decided to use the host versions here to
+            # ensure non-blocking invocation of plan, whereas the other way around would trigger a
+            # blocking copy to cpu. This way we trigger a non-blocking copy to device (note that
+            # this is safe since we do have pinned CPU memory for all our host-side arguments).
             self.prefill_wrapper.plan(
                 qo_indptr_host,
                 kv_page_indptr_host,
@@ -308,7 +313,6 @@
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
                 sm_scale=plan_params.sm_scale,
-                # max_token_per_sequence=max(seq_len_host).item(),
                 seq_lens=kv_lens_arr_host,
             )
             self.plan_params_prefill = plan_params
@@ -359,7 +363,6 @@ def _plan_decode(
             _plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
         # check if we are in cuda graph capture and just return the pre-cached decode wrapper
         if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
-            assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
             wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
             return wrapper
 
@@ -423,6 +426,23 @@ def prepare_flashinfer_metadata_fake(
     )
 
 
+def prepare_flashinfer_metadata_host(
+    batch_info_host: torch.Tensor,
+    cu_num_pages_host: torch.Tensor,
+    cache_loc_host: torch.Tensor,
+    last_page_len_host: torch.Tensor,
+) -> None:
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
+
+    if num_prefill == 0:
+        _GlobalFlashInferPlanner.plan_generate_only(
+            num_decode,
+            cu_num_pages_host[: num_decode + 1],
+            cache_loc_host,
+            last_page_len_host[:num_decode],
+        )
+
+
 @torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
 def flashinfer_mha_with_cache(
     # Q, K, V
@@ -438,7 +458,6 @@ def flashinfer_mha_with_cache(
     last_page_len: torch.Tensor,
     last_page_len_host: torch.Tensor,
     seq_len_with_cache_host: torch.Tensor,
-    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -502,7 +521,6 @@ def flashinfer_mha_with_cache(
             n_kv_heads=n_kv_heads,
             head_dim=head_dim,
             num_seq=num_prefill,
-            is_generate=False,
             page_size=k_cache.shape[1],
             q_dtype=q_prefill.dtype,
             kv_dtype=k_cache.dtype,
@@ -515,7 +533,6 @@ def flashinfer_mha_with_cache(
             kv_page_indices=cache_loc,
             kv_last_page_len_host=last_page_len_host[:num_prefill],
             kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
-            seq_len_host=seq_len_host[:num_prefill],
             plan_params=pp_prefill,
         )
 
@@ -539,7 +556,6 @@ def flashinfer_mha_with_cache(
             n_kv_heads=n_kv_heads,
             head_dim=head_dim,
             num_seq=num_decode,
-            is_generate=True,
             page_size=k_cache.shape[1],
             q_dtype=q_decode.dtype,
             kv_dtype=k_cache.dtype,
@@ -584,7 +600,6 @@ def flashinfer_mha_with_cache_fake(
     last_page_len: torch.Tensor,
     last_page_len_host: torch.Tensor,
     seq_len_with_cache_host: torch.Tensor,
-    seq_len_host: torch.Tensor,
     # EXTRA METADATA
     flashinfer_batch_indices: torch.Tensor,
     flashinfer_positions: torch.Tensor,
@@ -642,7 +657,6 @@ def get_standard_metadata_args(cls) -> List[str]:
             "last_page_len",
             "last_page_len_host",
             "seq_len_with_cache_host",
-            "seq_len_host",
         ]
 
     @classmethod
@@ -684,18 +698,8 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:
         return {"workspace_buffer": _init_workspace}
 
     @classmethod
-    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
-        batch_info = sequence_info._input_buffer.get_host_view("batch_info")
-        num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
-        # Call plan for generate-only batches.
-        if num_prefill == 0:
-            _GlobalFlashInferPlanner.plan_generate_only(
-                num_decode,
-                sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
-                sequence_info._input_buffer.get_host_view("cache_loc"),
-                sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
-            )
-        return
+    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
+        return prepare_flashinfer_metadata_host
 
     @classmethod
     def get_constants(cls, source_attn_node: Node) -> List[Constant]:
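
Note on the pinned-memory reasoning in the inline NOTE above: passing host tensors into plan() keeps the call non-blocking because a pinned host buffer can be copied to the device asynchronously, whereas going the other way would require a blocking copy to CPU. A minimal sketch of that pattern, independent of FlashInfer and purely illustrative (seq_lens_host and seq_lens_dev are made-up names):

import torch

use_cuda = torch.cuda.is_available()

# Host-side metadata buffer; .pin_memory() page-locks it so async copies are possible.
seq_lens_host = torch.tensor([3, 1, 7, 2, 0, 0, 0, 0], dtype=torch.int32)
if use_cuda:
    seq_lens_host = seq_lens_host.pin_memory()
    # Non-blocking host->device copy: safe because the pinned source buffer cannot be
    # paged out while the asynchronous transfer is in flight.
    seq_lens_dev = seq_lens_host.to("cuda", non_blocking=True)
    # A device->host copy at this point would instead synchronize the stream, which is
    # exactly what the NOTE above wants to avoid.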

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 22 additions & 5 deletions
@@ -1,5 +1,6 @@
 """Graph transformation to automatically add kv cache into fused MHA op."""
 
+import inspect
 import operator
 from typing import Dict, List, Optional, Tuple, Type
 
@@ -106,6 +107,23 @@ def _process_metadata_extra(
             gm, prep_meta_op, inputs_for_prep_meta, const_args, num_meta_out
         )
 
+    def _process_metadata_host(self, cm: CachedSequenceInterface):
+        """Process the host-side prepare metadata function."""
+        prep_meta_host_op = self.attn_descriptor.get_host_prepare_metadata_function()
+        if prep_meta_host_op is None:
+            return
+
+        # analyze the args of the host-side prepare metadata function using inspect
+        sig = inspect.signature(prep_meta_host_op)
+        args = sig.parameters.keys()
+
+        # check if all args are available in the cached sequence interface
+        unavailable_args = args - cm.info.available_args
+        assert not unavailable_args, f"Missing args in SequenceInfo: {unavailable_args=}"
+
+        # add the host-side prepare metadata function to the graph
+        cm.info.register_host_prepare_for_attention_forward(prep_meta_host_op, list(args))
+
     def _process_cache_node(self, gm: GraphModule, cache_name: str) -> Node:
         """Process the cache nodes by inserting a cached attention replacement op."""
         return add_graph_input(gm, cache_name)
@@ -173,6 +191,9 @@ def _apply(
         # insert metadata computation and extract each argument as a node
         meta_nodes_extra = self._process_metadata_extra(gm, cm, source_attn_nodes[0])
 
+        # Register host-side prepare_metadata function for attention descriptor.
+        self._process_metadata_host(cm)
+
         buffer_in_lookup: Dict[str, Node] = {}
 
         # replace fused attention node with attention node that has kv cache
@@ -213,11 +234,7 @@ def _apply(
                 buffer_in_nodes,
                 constants,
             )
-            # Attention descriptor should register its host function with SequenceInfo.
-            # This function will be called before graph invocation.
-            cm.info.register_host_prepare_for_attention_forward(
-                attn_descriptor.host_prepare_for_forward
-            )
+
             num_cached_attn_replacements += 1
 
         info = TransformInfo(
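
Note: _process_metadata_host derives the argument list purely from the function signature via inspect. The standalone sketch below shows that wiring; bind_host_prep, prep_host, and the available set are hypothetical names, not part of the transform:

import inspect
from typing import Callable, List, Set


def bind_host_prep(fn: Callable[..., None], available: Set[str]) -> List[str]:
    """Return the argument names a host-prep function declares in its signature."""
    args = list(inspect.signature(fn).parameters.keys())
    # fail early if the function asks for tensors the sequence interface does not expose
    missing = set(args) - available
    assert not missing, f"Missing args in SequenceInfo: {missing=}"
    return args


def prep_host(batch_info_host, cu_num_pages_host) -> None:
    ...


print(bind_host_prep(prep_host, {"batch_info_host", "cu_num_pages_host", "cache_loc_host"}))
# -> ['batch_info_host', 'cu_num_pages_host']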
