131 changes: 29 additions & 102 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -10,19 +10,7 @@
"""

from abc import ABC, abstractmethod
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -36,6 +24,10 @@
Constant = Union[int, float, str, None]


class PrepareMetadataHostCallable(Protocol):
def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
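# Not part of the diff: an illustrative sketch of a conforming callable (argument names are
# assumed). It receives the SequenceInfo tensors it was registered with as keyword arguments,
# runs eagerly on the host, and returns nothing.
def _example_host_prepare(seq_len_host: torch.Tensor, input_pos_host: torch.Tensor) -> None:
    num_seq = int((seq_len_host > 0).sum())
    _ = input_pos_host[:num_seq]  # e.g., derive backend-specific host metadata here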


class InputBuffer:
"""Manages contiguous memory buffers for efficient host-to-device transfers.

@@ -388,6 +380,9 @@ class SequenceInfo:
- _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
Mask scatter indices used by the overlap scheduler to scatter results back.

NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
the tensor "batch_info" is accessible as "batch_info_host" on the host.

################################################################################################

Here are a couple of notes to emphasize this notation:
@@ -508,24 +503,25 @@ def __init__(
# Create the InputBuffer that manages contiguous host and device memory
# Starts on default device; use to() to move to target device
self._input_buffer = InputBuffer(tensor_specs)
self._available_args = set(self._input_buffer.tensor_names) | {
f"{name}_host" for name in self._input_buffer.tensor_names
}

# Initialize args_list from tensor specs
self._args_list: Dict[str, List[int]] = {
name: [0] * numel for name, numel, _ in tensor_specs
}

self._active_args = ("input_ids", "position_ids")
self._shapeable_args = ("input_ids", "position_ids")
# Args that should be returned from host (pinned memory) instead of device in _named_args
self._host_return_args = ("batch_info", "logits_gather_info")
self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
############################################################################################

# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
############################################################################################

# HOST PREPARE FOR ATTENTION FORWARD #######################################################
self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []

# call reset once to set a consistent initial state
self.reset()
@@ -558,14 +554,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:

def _get_arg(self, name: str) -> torch.Tensor:
"""Get the argument from the input buffer either on device or host."""
if name in self._host_return_args:
arg = self._input_buffer.get_host_view(name)
if name.endswith("_host"):
arg = self._input_buffer.get_host_view(name.replace("_host", ""))
else:
arg = self._input_buffer.get_view(name)
return self._shape_for_forward(arg) if name in self._shapeable_args else arg
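# Not part of the diff: with this resolution rule, the plain name yields the device view of an
# InputBuffer entry while the "_host" suffix yields the pinned host view of the same entry
# (variable name seq_info assumed for illustration).
#   dev = seq_info._get_arg("batch_info")        # device-side view
#   host = seq_info._get_arg("batch_info_host")  # pinned host view of the same buffer
#   num_prefill, num_prefill_tokens, num_decode = host.tolist()  # reads CPU memory directly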

def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
# Build args dict, using host views for _host_return_args, device views otherwise
args = {k: self._get_arg(k) for k in self._active_args}

# check other args to include
@@ -577,7 +572,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor
@property
def available_args(self) -> Set[str]:
"""Return a list of available arguments."""
return set(self._input_buffer.tensor_names)
return self._available_args

@property
def named_args(self) -> Dict[str, torch.Tensor]:
@@ -697,68 +692,6 @@ def _get_cache_locations_and_pages_per_sequence(
pages_per_seq = [len(p) for p in page_assignments]
return cache_loc_flat, pages_per_seq

# TODO: remove after updating all cached backends
@classmethod
def _get_sanitized_seq_len(
cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> torch.Tensor:
"""Sanitize sequence lengths.

We want to cover the following scenarios with this function:

1. Pre-fill:
input_ids: [1, s_total, ...]
seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
---> returns [s_0, s_1, ..., s_{b-1}]
2. Decode:
input_ids: [b, 1, ...]
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|---- b ----|--- (max_batch_size - b) ---|
--> returns [1,] * b
3. Decode in Cudagraph:
input_ids: [b_cudagraph, 1, ...]
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|---- b ----|--- (max_batch_size - b) ---|

--> returns [1,] * b_cudagraph
Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
b_cudagraph.

# TODO: I could see one possible issue with this approach in the future.
# If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
# information. What could happen is that for the padded sequences the cache location
# tensors point to allocated pages. This could lead to a situation where we write into
# allocated cache pages polluting the cache of other sequences. Now this is not an issue
# if we write the dummy sequences into unallocated cache pages... One fix could be to
# pad not only the seq len but also pad the cache locations by just repeating the last
# valid cache location in the batch. This would ensure that the dummy sequences just
# repeat valid computation...
"""
_, s = input_or_position_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
if s > 1:
return seq_len[:num_seq].clone()
else:
return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

@staticmethod
def _get_sanitized_num_sequences(
input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> int:
"""Get number of sequences.

We make sure that this function is compatible with both torch graph capture and cudagraph.
Both can be a bit temperamental when trying to extract the number of sequences from a tensor
with max_batch_size or max_batch_size*max_seq_len.
"""
b, s = input_or_position_ids.shape[:2]
if s > 1:
num_seq = torch.sum(seq_len > 0)
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
else:
num_seq = b
return num_seq

def activate_arg(self, arg_name: str) -> bool:
"""Activate a desired argument.

@@ -869,7 +802,7 @@ def _store_arg(
self._args_list[name] = tnsr_like.copy()

# Only store to buffer when the argument is active or force_copy is True
if not (name in self._active_args or force_copy):
if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
return

# Store to the InputBuffer's pinned host memory
@@ -1090,12 +1023,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor):
def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""Maybe gather the logits if logits have not been gathered yet."""
num_tokens = logits.shape[0] * logits.shape[1]
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
if gather_required and num_tokens_to_gather < num_tokens:
logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
logits,
self._get_arg("logits_gather_indices"),
self._get_arg("logits_gather_info"),
self._get_arg("logits_gather_info_host"),
)
return logits.squeeze(int(self.is_generate))

@@ -1105,13 +1038,13 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
return list(torch.split(t_squeezed, self.seq_len))

def register_host_prepare_for_attention_forward(
self, host_function: Callable[["SequenceInfo"], None]
self, host_function: PrepareMetadataHostCallable, args: List[str]
):
self._host_prepare_functions.add(host_function)
self._host_prepare_functions.append((host_function, args))

def run_host_prepare_for_attention_forward(self) -> None:
for host_function in self._host_prepare_functions:
host_function(self)
for host_function, args in self._host_prepare_functions:
host_function(**{arg: self._get_arg(arg) for arg in args})
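# Not part of the diff: illustrative registration and use, reusing the hypothetical
# _example_host_prepare sketched next to PrepareMetadataHostCallable (seq_info is assumed to
# be a SequenceInfo instance). Registration pairs the callable with the SequenceInfo argument
# names it needs; the runner resolves those names via _get_arg and calls it with keyword
# arguments, outside of any captured graph.
#   seq_info.register_host_prepare_for_attention_forward(
#       _example_host_prepare, ["seq_len_host", "input_pos_host"]
#   )
#   seq_info.run_host_prepare_for_attention_forward()  # once per forward pass, on the host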


class MHACallable(Protocol):
Expand All @@ -1123,14 +1056,7 @@ def __call__(

class PrepareMetadataCallable(Protocol):
def __call__(
self,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
) -> List[torch.Tensor]: ...
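# Not part of the diff: a hypothetical callable matching this generalized signature. It receives
# whichever SequenceInfo tensors and constants the backend requested, positionally, and returns
# the metadata tensors that are later fed to the cached attention op.
def _example_prepare_metadata(
    *sequence_info_args_and_constants: Union[torch.Tensor, Constant],
) -> List[torch.Tensor]:
    seq_len, page_size = sequence_info_args_and_constants  # assumed ordering for illustration
    # page_size (a Constant) could drive paging-related metadata; here we only build cu_seqlen.
    cu_seqlen = torch.cat([seq_len.new_zeros(1), torch.cumsum(seq_len, dim=0)])
    return [cu_seqlen]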


@@ -1291,13 +1217,14 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
return []

@classmethod
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
"""Perform host-side preparation for the forward pass for the attention op.
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
"""Get function that performs host-side prep for the forward pass for the attention op.

This method is responsible for preparing the attention op for the forward pass.
This function is not expected to be graph capturable or compatible with cuda graphs.
This function is not expected to be graph capturable or compatible with cuda graphs. It can
use any argument from the SequenceInfo interface as an input argument.
"""
return
return None
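# Not part of the diff: a hypothetical backend override showing how this hook is meant to be
# used (class and base-class names are assumed; the other descriptor methods are omitted). It
# hands back the host-prepare callable sketched above so it can be registered on the
# SequenceInfo along with the argument names that callable consumes.
class _ExampleBackend(AttentionDescriptor):
    @classmethod
    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
        return _example_host_prepare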


class AttentionRegistry:
@@ -35,7 +35,7 @@ def fla_cached_delta_rule(
v: torch.Tensor,
beta: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@@ -58,7 +58,7 @@ def fla_cached_delta_rule(
y = torch.empty_like(v, memory_format=torch.contiguous_format)
y_flat = y.view(b * s, num_heads, -1)

num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
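# (Illustrative note, not in the diff) batch_info_host is the pinned host copy of batch_info,
# so the .tolist() above reads CPU memory directly rather than copying scalars off the device.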
num_seq = num_prefill + num_decode

# clean up metadata
@@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
v: torch.Tensor,
beta: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@@ -160,7 +160,7 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

@classmethod
def get_cache_initializers(