Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
)

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
Expand Down Expand Up @@ -512,6 +524,9 @@ def __init__(
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
############################################################################################

# HOST PREPARE FOR ATTENTION FORWARD #######################################################
self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()

# call reset once to set a consistent initial state
self.reset()

Expand Down Expand Up @@ -1089,6 +1104,15 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(int(self.is_generate))
return list(torch.split(t_squeezed, self.seq_len))

def register_host_prepare_for_attention_forward(
self, host_function: Callable[["SequenceInfo"], None]
):
self._host_prepare_functions.add(host_function)

def run_host_prepare_for_attention_forward(self) -> None:
for host_function in self._host_prepare_functions:
host_function(self)


class MHACallable(Protocol):
def __call__(
Expand Down Expand Up @@ -1266,6 +1290,15 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
"""
return []

@classmethod
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
"""Perform host-side preparation for the forward pass for the attention op.

This method is responsible for preparing the attention op for the forward pass.
This function is not expected to be graph capturable or compatible with cuda graphs.
"""
return


class AttentionRegistry:
"""A simple registry to look up different attention implementations."""
Expand Down
253 changes: 234 additions & 19 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,156 @@
)


# TODO: remove this when flashinfer version is updated to >0.5
def fast_decode_plan(
    wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    indptr: torch.Tensor,
    indices: torch.Tensor,
    last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = None,
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: bool = False,
    global_override_indptr_cpu: Optional[torch.Tensor] = None,
) -> None:
    """
    Copied from flashinfer.decode.fast_decode_plan in flashinfer version >0.5.
    Does not exist in flashinfer version 0.3.1, hence copied here.

    Plans a batched paged-KV decode directly against the wrapper's cached plan
    module, populating ``wrapper._plan_info`` and (when CUDA graphs are disabled)
    rebinding the wrapper's paged-KV metadata buffers, so a subsequent
    ``wrapper.run`` can execute without re-planning.

    NOTE(review): ``disable_split_kv`` is accepted for signature compatibility but
    is never read below, and ``fixed_split_size`` is normalized but not forwarded
    to either ``plan`` call in this copied version — presumably matching the
    upstream copy; verify against flashinfer >0.5 before modifying.
    """
    batch_size = len(last_page_len)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        # A single `data_type` acts as the default for both q and kv dtypes.
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type

    if wrapper.use_tensor_cores:
        # Tensor-core path behaves like prefill with one query token per sequence,
        # hence the trivial [0, 1, ..., batch_size] qo indptr on the host.
        qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
        # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
        if fixed_split_size is None:
            fixed_split_size = -1

    if wrapper.is_cuda_graph_enabled:
        # Under CUDA graphs the wrapper's metadata buffers are fixed at init time;
        # only validate shapes here instead of rebinding them.
        if batch_size != wrapper._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                " mismatches the batch size set during initialization {}".format(
                    batch_size, wrapper._fixed_batch_size
                )
            )
        if len(indices) > len(wrapper._paged_kv_indices_buf):
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
    else:
        # No CUDA graph: point the wrapper directly at the caller's metadata tensors.
        wrapper._paged_kv_indptr_buf = indptr
        wrapper._paged_kv_indices_buf = indices
        wrapper._paged_kv_last_page_len_buf = last_page_len
        if wrapper.use_tensor_cores:
            wrapper._qo_indptr_buf = qo_indptr_host.to(wrapper.device, non_blocking=non_blocking)

    # Create empty tensors for dtype info if needed
    # (zero-element tensors only carry dtype information into the plan call).
    empty_q_data = torch.empty(
        0,
        dtype=(getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type),
        device=wrapper.device,
    )

    empty_kv_cache = torch.empty(
        0,
        dtype=(getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type),
        device=wrapper.device,
    )

    # Host-side indptr: allow the caller to supply a pre-staged CPU copy to avoid
    # a device-to-host sync on `indptr.cpu()`.
    indptr_host = (
        global_override_indptr_cpu if global_override_indptr_cpu is not None else indptr.cpu()
    )

    with torch.cuda.device(wrapper.device):
        if wrapper.use_tensor_cores:
            # ALSO convert last_page_len to CPU
            if page_size == 1:
                # When page size is 1, last_page_len is always 1.
                # Directly construct the host tensor rather than executing a device-to-host copy.
                last_page_len_host = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
            else:
                last_page_len_host = last_page_len.cpu()

            kv_lens_arr_host = flashinfer.get_seq_lens(indptr_host, last_page_len_host, page_size)

            try:
                # Make sure we pass exactly 15 arguments for tensor core version
                # NOTE(review): the error message below says "standard plan" even on
                # this tensor-core path — kept verbatim from the upstream copy.
                wrapper._plan_info = wrapper._cached_module.plan(
                    wrapper._float_workspace_buffer,
                    wrapper._int_workspace_buffer,
                    wrapper._pin_memory_int_workspace_buffer,
                    qo_indptr_host,
                    indptr_host,
                    kv_lens_arr_host,
                    batch_size,  # total_num_rows
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    wrapper.is_cuda_graph_enabled,
                    head_dim,
                    head_dim,
                    False,  # causal
                )
            except Exception as e:
                raise RuntimeError(f"Error in standard plan: {e}") from e
        else:
            try:
                # Make sure we pass exactly 15 arguments for standard version
                wrapper._plan_info = wrapper._cached_module.plan(
                    wrapper._float_workspace_buffer,
                    wrapper._int_workspace_buffer,
                    wrapper._pin_memory_int_workspace_buffer,
                    indptr_host,
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    wrapper.is_cuda_graph_enabled,
                    window_left,
                    logits_soft_cap,
                    head_dim,
                    head_dim,
                    empty_q_data,
                    empty_kv_cache,
                )
            except Exception as e:
                raise RuntimeError(f"Error in standard plan: {e}") from e

    # Stash run-time parameters on the wrapper for the subsequent run() call.
    wrapper._pos_encoding_mode = pos_encoding_mode
    wrapper._window_left = window_left
    wrapper._logits_soft_cap = logits_soft_cap
    wrapper._sm_scale = sm_scale
    wrapper._rope_scale = rope_scale
    wrapper._rope_theta = rope_theta


@dataclass
class PlanParams:
"""Parameters that affect the flashinfer execution plan."""
Expand Down Expand Up @@ -52,21 +202,42 @@ class _FlashInferPlanner:
workspace_buffer: Optional[torch.Tensor]
prefill_wrapper: Optional[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
cached_cuda_graph_decode_wrappers: Dict[
PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
]
plan_params: Optional[PlanParams]

def __init__(self):
self.workspace_buffer = None
self.prefill_wrapper = None
self.decode_wrapper = None
self.cached_decode_wrappers = {}
self.cached_cuda_graph_decode_wrappers = {}
self.plan_params = None

def _init_decode_wrapper(self):
def _init_decode_wrapper(
self,
use_cuda_graph: bool = False,
indptr: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
last_page_len: Optional[torch.Tensor] = None,
):
assert self.workspace_buffer is not None
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD", use_tensor_cores=True
)
if use_cuda_graph:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
paged_kv_indptr_buffer=indptr,
paged_kv_indices_buffer=indices,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=True,
)
else:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_tensor_cores=True,
)

def init_workspace(self, workspace_buffer: torch.Tensor):
self.__init__() # reset all state
Expand All @@ -84,6 +255,30 @@ def init_workspace(self, workspace_buffer: torch.Tensor):
    def reset(self) -> None:
        # Drop the cached plan parameters so the next plan() call re-plans
        # instead of reusing a stale execution plan.
        self.plan_params = None

def plan_generate_only(
self,
num_seq: int,
cu_num_pages: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
):
for plan_params in self.cached_cuda_graph_decode_wrappers:
if plan_params.num_seq == num_seq:
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
fast_decode_plan(
wrapper,
cu_num_pages,
cache_loc,
last_page_len,
plan_params.n_heads,
plan_params.n_kv_heads,
plan_params.head_dim,
plan_params.page_size,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
)

def plan(
self,
qo_indptr: torch.Tensor,
Expand All @@ -96,7 +291,9 @@ def plan(
flashinfer.BatchDecodeWithPagedKVCacheWrapper,
]:
# plan decode helper function
def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
def _plan_decode(
wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
):
wrapper.plan(
kv_page_indptr,
kv_page_indices,
Expand All @@ -111,18 +308,23 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
)

# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
if cuda_graph_state.in_warm_up() and plan_params not in self.cached_decode_wrappers:
self.cached_decode_wrappers[plan_params] = self._init_decode_wrapper()
_plan_decode(self.cached_decode_wrappers[plan_params])

if (
cuda_graph_state.in_warm_up()
and plan_params not in self.cached_cuda_graph_decode_wrappers
):
# During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
wrapper = self._init_decode_wrapper(
use_cuda_graph=True,
indptr=kv_page_indptr,
indices=kv_page_indices,
last_page_len=kv_last_page_len,
)
self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper
_plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
# check if we are in cuda graph capture and just return the pre-cached decode wrapper
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
wrapper = self.cached_decode_wrappers[plan_params]
# copy the metadata to the wrapper to ensure it is up-to-date for graph replay!
wrapper._paged_kv_indptr_buf.copy_(kv_page_indptr)
wrapper._paged_kv_indices_buf.copy_(kv_page_indices)
wrapper._paged_kv_last_page_len_buf.copy_(kv_last_page_len)
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
return wrapper

# check for re-planning
Expand Down Expand Up @@ -167,14 +369,13 @@ def prepare_flashinfer_metadata(
https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.plan
to understand the convention.
"""
# reset the planner
_GlobalFlashInferPlanner.reset()

# retrieve host-side metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_seq = num_prefill + num_decode
num_tokens = num_prefill_tokens + num_decode

_GlobalFlashInferPlanner.reset()

qo_indptr = cu_seqlen[: num_seq + 1]

# NOTE: in theory we could easily precompute batch_indices. And positions is just position_ids
Expand Down Expand Up @@ -398,6 +599,20 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:

return {"workspace_buffer": _init_workspace}

@classmethod
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
batch_info = sequence_info._input_buffer.get_host_view("batch_info")
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
# Call plan for generate-only batches.
if num_prefill == 0:
_GlobalFlashInferPlanner.plan_generate_only(
num_decode,
sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
sequence_info._input_buffer.get_host_view("cache_loc"),
sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
)
return

@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ def _build_input_ids(request) -> Tuple[List[int], List[int], bool]:
if new_tokens is not None:
self.cache_seq_interface.info.rescatter_input_ids(new_tokens.flatten())

self.cache_seq_interface.info.run_host_prepare_for_attention_forward()

self.iter_states["num_ctx_requests"] = num_ctx_requests
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
# TODO: handle extend requests and draft requests for specdec
Expand Down
Loading