Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,22 @@ def positions(self) -> torch.Tensor:

def __post_init__(self) -> None:
super().__post_init__()
self._post_init_with_buffers(self.cuda_graph_buffers)

def _post_init_with_buffers(self, buffers) -> None:
capture_graph = torch.cuda.is_current_stream_capturing()

if self.workspace_buffer is None:
# Note: even though flashinfer only recommends 128 MB, we have to push it
# a bit higher to cover all possible CUDA graph cases. If it's too small,
# warmup will crash.
self.workspace_buffer = torch.empty(320 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
self.workspace_buffer = self.get_empty(
buffers,
(320 * 1024 * 1024, ),
dtype=torch.uint8,
cache_name="workspace_buffer",
capture_graph=capture_graph,
)

self.paged_kv_indptr_decode = torch.empty((self.max_num_requests + 1, ),
device='cuda',
Expand Down Expand Up @@ -163,9 +171,13 @@ def __post_init__(self) -> None:

if self.kv_cache_manager is not None:
max_num_pages = self.kv_cache_manager.blocks_in_primary_pool
self._paged_kv_indices = torch.empty((max_num_pages, ),
device='cuda',
dtype=torch.int)
self._paged_kv_indices = self.get_empty(
buffers,
(max_num_pages, ),
dtype=torch.int,
cache_name="_paged_kv_indices",
capture_graph=capture_graph,
)

def create_cuda_graph_metadata(self,
max_batch_size: int,
Expand Down
43 changes: 43 additions & 0 deletions tensorrt_llm/_torch/attention_backend/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,49 @@ def update_for_spec_dec(self) -> None:
Hook to be called during forward when using spec-dec one-model mode.
"""

@staticmethod
def get_empty(buffers,
              tensor_shape: list[int],
              dtype: torch.dtype,
              cache_name: str,
              capture_graph: bool = False) -> torch.Tensor:
    """
    Finds a compatible, reusable buffer from a cache or creates a new one.

    This function searches for a pre-allocated tensor (buffer) that can be
    reused for an operation involving a tensor with the shape of `tensor_shape`.

    The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
    (NOTE(review): this rule is enforced inside `buffers.get_buffer`, which is
    not visible here — confirm against its implementation.)

    If a compatible buffer is found, it's returned immediately. Otherwise, a new
    buffer is allocated on the 'cuda' device with the given 'tensor_shape' and 'dtype'.

    Despite the name, the no-cache fallback uses `torch.zeros`, so a freshly
    created tensor is zero-filled rather than uninitialized.

    Args:
        buffers: The buffer cache to draw from (an object exposing
            `get_buffer(shape, dtype, name, capture_graph)`), or None to
            always allocate a fresh zero-filled CUDA tensor.
        tensor_shape: The required shape. Callers pass both lists and tuples.
        dtype: The required dtype.
        cache_name: The key for the specific list of buffers to search in.
        capture_graph: Whether a CUDA graph capture is in progress; forwarded
            to the buffer cache so it can manage graph-safe reuse.
    Returns:
        An existing compatible buffer or a newly created one.
    """
    if buffers is None:
        return torch.zeros(tensor_shape, device='cuda', dtype=dtype)

    return buffers.get_buffer(tensor_shape, dtype, cache_name,
                              capture_graph)

@staticmethod
def get_empty_like(buffers,
                   like_tensor: torch.Tensor,
                   cache_name: str,
                   capture_graph: bool = False) -> torch.Tensor:
    """
    Convenience wrapper around `get_empty` that takes its shape and dtype
    from an existing tensor.

    Args:
        buffers: The buffer cache to draw from, or None to allocate fresh.
        like_tensor: Tensor whose shape and dtype the result must match.
        cache_name: The key for the specific list of buffers to search in.
        capture_graph: Whether a CUDA graph capture is in progress.
    Returns:
        An existing compatible buffer or a newly created one, shaped and
        typed like `like_tensor`.
    """
    required_shape = like_tensor.shape
    required_dtype = like_tensor.dtype
    return AttentionMetadata.get_empty(
        buffers,
        required_shape,
        dtype=required_dtype,
        cache_name=cache_name,
        capture_graph=capture_graph,
    )


class PositionalEmbedder(Protocol):
"""
Expand Down
55 changes: 36 additions & 19 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,12 @@ def __post_init__(self):

capture_graph = torch.cuda.is_current_stream_capturing()

def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
if self.cuda_graph_buffers is None:
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
cache_name, capture_graph)

self.indexer_k_cache_block_offsets = get_empty(
self.indexer_k_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
cache_name="indexer_k_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_indexer_k_cache_block_offsets = torch.zeros_like(
self.indexer_k_cache_block_offsets,
Expand All @@ -324,41 +319,49 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,

# For mla_rope_append_paged_kv_assign_q
if not self.enable_context_mla_with_cached_kv:
self.ctx_cached_token_indptr = get_empty(
self.ctx_cached_token_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(
self.ctx_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.ctx_kv_indptr = get_empty(
self.ctx_kv_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="ctx_kv_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_ctx_kv_indptr = torch.zeros_like(
self.ctx_kv_indptr,
device='cpu',
pin_memory=True,
)
# New generation buffers for dsa
self.gen_cached_token_indptr = get_empty(
self.gen_cached_token_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="gen_cached_token_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_gen_cached_token_indptr = torch.zeros_like(
self.gen_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.gen_kv_indptr = get_empty(
self.gen_kv_indptr = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests + 1, ),
cache_name="gen_kv_indptr",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_gen_kv_indptr = torch.zeros_like(
self.gen_kv_indptr,
Expand All @@ -367,52 +370,66 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
)
# Indexer metadata
# Separate slot mappings for non-interleaved layout (flat byte indices)
self.slot_mapping_fp8 = get_empty(
self.slot_mapping_fp8 = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="slot_mapping_fp8",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_slot_mapping_fp8 = torch.zeros_like(
self.slot_mapping_fp8,
device='cpu',
pin_memory=True,
)
self.slot_mapping_scale = get_empty(
self.slot_mapping_scale = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="slot_mapping_scale",
dtype=torch.int64,
capture_graph=capture_graph,
)
self.host_slot_mapping_scale = torch.zeros_like(
self.slot_mapping_scale,
device='cpu',
pin_memory=True,
)
# Per-token request index buffer for topk_indices conversion
self.req_idx_per_token = get_empty(
self.req_idx_per_token = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="req_idx_per_token",
dtype=torch.int32,
capture_graph=capture_graph,
)
# Block table for topk_indices conversion (shared for context and generation)
self.block_table = get_empty(
self.block_table = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
cache_name="block_table",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.scheduler_metadata_buffer = get_empty(
self.scheduler_metadata_buffer = self.get_empty(
self.cuda_graph_buffers,
(self.num_sms + 1, 2),
cache_name="scheduler_metadata_buffer",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.cu_seqlen_ks = get_empty(
self.cu_seqlen_ks = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="cu_seqlen_ks",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.cu_seqlen_ke = get_empty(
self.cu_seqlen_ke = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_tokens, ),
cache_name="cu_seqlen_ke",
dtype=torch.int32,
capture_graph=capture_graph,
)

def prepare(self):
Expand Down
9 changes: 7 additions & 2 deletions tensorrt_llm/_torch/attention_backend/sparse/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ def __post_init__(self):
if self.sparse_attention_config is None:
raise ValueError("Sparse attention config is not set")
self.prompt_budget = self.sparse_attention_config.prompt_budget
self.kt_cache_block_offsets = torch.empty(

capture_graph = torch.cuda.is_current_stream_capturing()
self.kt_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
[
self.max_num_sequences,
self.kv_cache_manager.max_kt_blocks_per_seq
],
dtype=torch.int32,
device='cuda',
cache_name="kt_cache_block_offsets",
capture_graph=capture_graph,
)

self.host_kt_cache_block_offsets = torch.zeros_like(
self.kt_cache_block_offsets,
device='cpu',
Expand Down
Loading
Loading