
Commit 1bbe71b

[#10244][feat] AutoDeploy: separate prefill/decode in flashinfer (#10252)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 9085021 commit 1bbe71b

25 files changed: +441 −330 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 29 additions & 102 deletions
@@ -10,19 +10,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import (
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -36,6 +24,10 @@
 Constant = Union[int, float, str, None]
 
 
+class PrepareMetadataHostCallable(Protocol):
+    def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
+
+
 class InputBuffer:
     """Manages contiguous memory buffers for efficient host-to-device transfers.
 
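As a point of reference, here is a minimal, hypothetical sketch of a callable that satisfies the new PrepareMetadataHostCallable protocol: it receives whichever SequenceInfo tensors were requested at registration time as keyword arguments and fills host-side metadata in place, returning None. The tensor names seq_len_host and cu_seqlen_host are illustrative assumptions, not names taken from this diff.

    import torch

    def prepare_cu_seqlen(seq_len_host: torch.Tensor, cu_seqlen_host: torch.Tensor) -> None:
        # Illustrative host-side prep: build cumulative sequence lengths on pinned
        # host memory so downstream ops can read them without a device sync.
        cu_seqlen_host[0] = 0
        cu_seqlen_host[1:] = torch.cumsum(seq_len_host, dim=0)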
@@ -388,6 +380,9 @@ class SequenceInfo:
     - _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
       Mask scatter indices used by the overlap scheduler to scatter results back.
 
+    NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
+    the tensor "batch_info" is accessible as "batch_info_host" on the host.
+
     ################################################################################################
 
     Here are a couple of notes to emphasize this notation:
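A short usage sketch of the "_host" convention described in the NOTE above, assuming a SequenceInfo instance seq_info whose buffer contains a "batch_info" tensor (the internal _get_arg helper is shown further down in this diff):

    # Device view: lives on the target device and is what gets passed to custom ops.
    batch_info = seq_info._get_arg("batch_info")

    # Pinned host view of the same buffer: reading scalars here does not force a
    # device synchronization.
    num_prefill, num_prefill_tokens, num_decode = seq_info._get_arg("batch_info_host").tolist()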
@@ -508,24 +503,25 @@ def __init__(
         # Create the InputBuffer that manages contiguous host and device memory
         # Starts on default device; use to() to move to target device
         self._input_buffer = InputBuffer(tensor_specs)
+        self._available_args = set(self._input_buffer.tensor_names) | {
+            f"{name}_host" for name in self._input_buffer.tensor_names
+        }
 
         # Initialize args_list from tensor specs
         self._args_list: Dict[str, List[int]] = {
             name: [0] * numel for name, numel, _ in tensor_specs
         }
 
         self._active_args = ("input_ids", "position_ids")
-        self._shapeable_args = ("input_ids", "position_ids")
-        # Args that should be returned from host (pinned memory) instead of device in _named_args
-        self._host_return_args = ("batch_info", "logits_gather_info")
+        self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
         ############################################################################################
 
         # EXTRA TENSOR FIELDS ######################################################################
         self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
         ############################################################################################
 
         # HOST PREPARE FOR ATTENTION FORWARD #######################################################
-        self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
+        self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []
 
         # call reset once to set a consistent initial state
         self.reset()
@@ -558,14 +554,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
 
     def _get_arg(self, name: str) -> torch.Tensor:
         """Get the argument from the input buffer either on device or host."""
-        if name in self._host_return_args:
-            arg = self._input_buffer.get_host_view(name)
+        if name.endswith("_host"):
+            arg = self._input_buffer.get_host_view(name.replace("_host", ""))
         else:
             arg = self._input_buffer.get_view(name)
         return self._shape_for_forward(arg) if name in self._shapeable_args else arg
 
     def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
-        # Build args dict, using host views for _host_return_args, device views otherwise
         args = {k: self._get_arg(k) for k in self._active_args}
 
         # check other args to include
@@ -577,7 +572,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor
     @property
     def available_args(self) -> Set[str]:
         """Return a list of available arguments."""
-        return set(self._input_buffer.tensor_names)
+        return self._available_args
 
     @property
     def named_args(self) -> Dict[str, torch.Tensor]:
@@ -697,68 +692,6 @@ def _get_cache_locations_and_pages_per_sequence(
         pages_per_seq = [len(p) for p in page_assignments]
         return cache_loc_flat, pages_per_seq
 
-    # TODO: remove after updating all cached backends
-    @classmethod
-    def _get_sanitized_seq_len(
-        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> torch.Tensor:
-        """Sanitize sequence lengths.
-
-        We want to cover the following scenarios with this function:
-
-        1. Pre-fill:
-            input_ids: [1, s_total, ...]
-            seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
-            ---> returns [s_0, s_1, ..., s_{b-1}]
-        2. Decode:
-            input_ids: [b, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-            --> returns [1,] * b
-        3. Decode in Cudagraph:
-            input_ids: [b_cudagraph, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-
-            --> returns [1,] * b_cudagraph
-            Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
-            b_cudagraph.
-
-        # TODO: I could see one possible issue with this approach in the future.
-        # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
-        # information. What could happen is that the for the padded sequences the cache location
-        # tensors point to allocated pages. This could lead to a situation where we write into
-        # allocated cache pages polluting the cache of other sequences. Now this is not an issue
-        # if we write the dummy sequences into unallocated cache pages... One fix could be to
-        # pad not only the seq len but also pad the cache locations by just repeating the last
-        # valid cache location in the batch. This would ensure that the dummy sequences just
-        # repeats valid computation...
-        """
-        _, s = input_or_position_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
-        if s > 1:
-            return seq_len[:num_seq].clone()
-        else:
-            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
-
-    @staticmethod
-    def _get_sanitized_num_sequences(
-        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> int:
-        """Get number of sequences.
-
-        We makes sure that this function is compatible with both torch graph capture and cudagraph.
-        Both can be a bit temparamental when trying to extract the number of sequences from a tensor
-        with max_batch_size or max_batch_size*max_seq_len.
-        """
-        b, s = input_or_position_ids.shape[:2]
-        if s > 1:
-            num_seq = torch.sum(seq_len > 0)
-            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
-        else:
-            num_seq = b
-        return num_seq
-
     def activate_arg(self, arg_name: str) -> bool:
         """Activate a desired argument.
 
@@ -869,7 +802,7 @@ def _store_arg(
         self._args_list[name] = tnsr_like.copy()
 
         # Only store to buffer when the argument is active or force_copy is True
-        if not (name in self._active_args or force_copy):
+        if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
             return
 
         # Store to the InputBuffer's pinned host memory
@@ -1090,12 +1023,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor):
     def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
         """Maybe gather the logits if logits have not been gathered yet."""
         num_tokens = logits.shape[0] * logits.shape[1]
-        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
+        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
         if gather_required and num_tokens_to_gather < num_tokens:
             logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
                 logits,
                 self._get_arg("logits_gather_indices"),
-                self._get_arg("logits_gather_info"),
+                self._get_arg("logits_gather_info_host"),
             )
         return logits.squeeze(int(self.is_generate))
 
@@ -1105,13 +1038,13 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         return list(torch.split(t_squeezed, self.seq_len))
 
     def register_host_prepare_for_attention_forward(
-        self, host_function: Callable[["SequenceInfo"], None]
+        self, host_function: PrepareMetadataHostCallable, args: List[str]
     ):
-        self._host_prepare_functions.add(host_function)
+        self._host_prepare_functions.append((host_function, args))
 
     def run_host_prepare_for_attention_forward(self) -> None:
-        for host_function in self._host_prepare_functions:
-            host_function(self)
+        for host_function, args in self._host_prepare_functions:
+            host_function(**{arg: self._get_arg(arg) for arg in args})
 
 
 class MHACallable(Protocol):
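To make the reworked registration concrete, a hedged sketch of how a backend might use it (only register_host_prepare_for_attention_forward and run_host_prepare_for_attention_forward come from this diff; the function name, its body, and the seq_info instance are illustrative assumptions):

    import torch

    def prepare_backend_metadata(batch_info_host: torch.Tensor) -> None:
        # Hypothetical host-side prep; a real backend would do its planning work here.
        num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()

    # Register the callable together with the SequenceInfo arg names it needs;
    # the runner resolves each name via _get_arg, including "_host" views.
    seq_info.register_host_prepare_for_attention_forward(
        prepare_backend_metadata, ["batch_info_host"]
    )

    # Called before each attention forward, outside of any CUDA graph capture.
    seq_info.run_host_prepare_for_attention_forward()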
@@ -1123,14 +1056,7 @@ def __call__(
 
 class PrepareMetadataCallable(Protocol):
     def __call__(
-        self,
-        position_ids: torch.Tensor,
-        seq_len: torch.Tensor,
-        input_pos: torch.Tensor,
-        cache_loc: torch.Tensor,
-        pages_per_seq: torch.Tensor,
-        slot_idx: torch.Tensor,
-        page_size: int,
+        self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
     ) -> List[torch.Tensor]: ...
 
 
@@ -1291,13 +1217,14 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         return []
 
     @classmethod
-    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
-        """Perform host-side preparation for the forward pass for the attention op.
+    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
+        """Get function that performs host-side prep for the forward pass for the attention op.
 
         This method is responsible for preparing the attention op for the forward pass.
-        This function is not expected to be graph capturable or compatible with cuda graphs.
+        This function is not expected to be graph capturable or compatible with cuda graphs. It can
+        use any argument from the SequenceInfo interface as input argument to its function.
         """
-        return
+        return None
 
 
 class AttentionRegistry:
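For orientation, a hedged sketch of how a cached-attention backend could override the new hook; the method name, signature, and return type come from this diff, while the base-class name AttentionDescriptor, the inner function, and its argument are assumptions:

    from typing import Optional

    import torch

    class MyCachedBackend(AttentionDescriptor):  # base-class name assumed
        @classmethod
        def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
            def _host_prepare(batch_info_host: torch.Tensor) -> None:
                # Plain Python/CPU work on pinned host tensors; not graph-capturable.
                num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()

            return _host_prepare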

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ def fla_cached_delta_rule(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -58,7 +58,7 @@ def fla_cached_delta_rule(
     y = torch.empty_like(v, memory_format=torch.contiguous_format)
     y_flat = y.view(b * s, num_heads, -1)
 
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
 
     # clean up metadata
@@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -160,7 +160,7 @@ def get_cached_attention_op(cls) -> MHACallable:
 
     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
 
     @classmethod
     def get_cache_initializers(
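The switch to batch_info_host means the op reads the prefill/decode split with a plain .tolist() on a pinned host tensor. A hedged sketch of the resulting split, assuming the flattened-token layout where prefill tokens come first followed by one token per decode sequence (variable names are illustrative, not copied from the op):

    import torch

    def split_prefill_decode(y_flat: torch.Tensor, batch_info_host: torch.Tensor):
        # Host read; no device synchronization is triggered.
        num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
        prefill_tokens = y_flat[:num_prefill_tokens]
        decode_tokens = y_flat[num_prefill_tokens : num_prefill_tokens + num_decode]
        return prefill_tokens, decode_tokens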
