Skip to content

Commit bf078a1

Browse files
committed
Add wrapper class of tensor
Support sharing more buffers Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 6b9b73e commit bf078a1

File tree

8 files changed

+347
-123
lines changed

8 files changed

+347
-123
lines changed

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,20 @@ def positions(self) -> torch.Tensor:
124124

125125
def __post_init__(self) -> None:
126126
super().__post_init__()
127+
self._post_init_with_buffers(self.cuda_graph_buffers)
128+
129+
def _post_init_with_buffers(self, buffers) -> None:
130+
capture_graph = torch.cuda.is_current_stream_capturing()
127131

128132
if self.workspace_buffer is None:
129133
# Note: even though flashinfer only recommends 128 MB, we have to push it
130134
# a bit higher to cover all possible CUDA graph cases. If it's too small,
131135
# warmup will crash.
132-
self.workspace_buffer = torch.empty(320 * 1024 * 1024,
133-
dtype=torch.uint8,
134-
device="cuda")
136+
self.workspace_buffer = self.get_empty(
137+
buffers, (320 * 1024 * 1024, ),
138+
dtype=torch.uint8,
139+
cache_name="workspace_buffer",
140+
capture_graph=capture_graph)
135141

136142
self.paged_kv_indptr_decode = torch.empty((self.max_num_requests + 1, ),
137143
device='cuda',
@@ -163,9 +169,11 @@ def __post_init__(self) -> None:
163169

164170
if self.kv_cache_manager is not None:
165171
max_num_pages = self.kv_cache_manager.blocks_in_primary_pool
166-
self._paged_kv_indices = torch.empty((max_num_pages, ),
167-
device='cuda',
168-
dtype=torch.int)
172+
self._paged_kv_indices = self.get_empty(
173+
buffers, (max_num_pages, ),
174+
dtype=torch.int,
175+
cache_name="_paged_kv_indices",
176+
capture_graph=capture_graph)
169177

170178
def create_cuda_graph_metadata(self,
171179
max_batch_size: int,

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,47 @@ def update_for_spec_dec(self) -> None:
349349
Hook to be called during forward when using spec-dec one-model mode.
350350
"""
351351

352+
@staticmethod
353+
def get_empty(buffers,
354+
tensor_shape: list[int],
355+
dtype: torch.dtype,
356+
cache_name: str,
357+
capture_graph: bool = False) -> torch.Tensor:
358+
"""
359+
Finds a compatible, reusable buffer from a cache or creates a new one.
360+
361+
This function searches for a pre-allocated tensor (buffer) that can be
362+
reused for an operation involving a tensor with the shape of `tensor_shape`.
363+
364+
The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
365+
366+
If a compatible buffer is found, it's returned immediately. Otherwise, a new
367+
buffer is allocated on the 'cuda' device with the given properties of 'tensor_shape' and 'dtype'.
368+
369+
Args:
370+
tensor_shape: The required shape.
371+
dtype: The required dtype.
372+
cache_name: The key for the specific list of buffers to search in.
373+
Returns:
374+
An existing compatible buffer or a newly created one.
375+
"""
376+
if buffers is None:
377+
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
378+
379+
return buffers.get_buffer(tensor_shape, dtype, cache_name,
380+
capture_graph)
381+
382+
@staticmethod
383+
def get_empty_like(buffers,
384+
like_tensor: torch.Tensor,
385+
cache_name: str,
386+
capture_graph: bool = False) -> torch.Tensor:
387+
return AttentionMetadata.get_empty(buffers,
388+
like_tensor.shape,
389+
dtype=like_tensor.dtype,
390+
cache_name=cache_name,
391+
capture_graph=capture_graph)
392+
352393

353394
class PositionalEmbedder(Protocol):
354395
"""

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -286,18 +286,11 @@ def __post_init__(self):
286286

287287
capture_graph = torch.cuda.is_current_stream_capturing()
288288

289-
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
290-
cache_name: str) -> torch.Tensor:
291-
if self.cuda_graph_buffers is None:
292-
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
293-
return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
294-
cache_name, capture_graph)
295-
296289
self.indexer_k_cache_block_offsets = get_empty(
297290
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
298291
cache_name="indexer_k_cache_block_offsets",
299292
dtype=torch.int32,
300-
)
293+
capture_graph=capture_graph)
301294
self.host_indexer_k_cache_block_offsets = torch.zeros_like(
302295
self.indexer_k_cache_block_offsets,
303296
device='cpu',
@@ -310,17 +303,16 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
310303
(self.max_num_requests + 1, ),
311304
cache_name="ctx_cached_token_indptr",
312305
dtype=torch.int64,
313-
)
306+
capture_graph=capture_graph)
314307
self.host_ctx_cached_token_indptr = torch.zeros_like(
315308
self.ctx_cached_token_indptr,
316309
device='cpu',
317310
pin_memory=True,
318311
)
319-
self.ctx_kv_indptr = get_empty(
320-
(self.max_num_requests + 1, ),
321-
cache_name="ctx_kv_indptr",
322-
dtype=torch.int64,
323-
)
312+
self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
313+
cache_name="ctx_kv_indptr",
314+
dtype=torch.int64,
315+
capture_graph=capture_graph)
324316
self.host_ctx_kv_indptr = torch.zeros_like(
325317
self.ctx_kv_indptr,
326318
device='cpu',
@@ -331,71 +323,65 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
331323
(self.max_num_requests + 1, ),
332324
cache_name="gen_cached_token_indptr",
333325
dtype=torch.int64,
334-
)
326+
capture_graph=capture_graph)
335327
self.host_gen_cached_token_indptr = torch.zeros_like(
336328
self.gen_cached_token_indptr,
337329
device='cpu',
338330
pin_memory=True,
339331
)
340-
self.gen_kv_indptr = get_empty(
341-
(self.max_num_requests + 1, ),
342-
cache_name="gen_kv_indptr",
343-
dtype=torch.int64,
344-
)
332+
self.gen_kv_indptr = get_empty((self.max_num_requests + 1, ),
333+
cache_name="gen_kv_indptr",
334+
dtype=torch.int64,
335+
capture_graph=capture_graph)
345336
self.host_gen_kv_indptr = torch.zeros_like(
346337
self.gen_kv_indptr,
347338
device='cpu',
348339
pin_memory=True,
349340
)
350341
# Indexer metadata
351342
# Separate slot mappings for non-interleaved layout (flat byte indices)
352-
self.slot_mapping_fp8 = get_empty(
353-
(self.max_num_tokens, ),
354-
cache_name="slot_mapping_fp8",
355-
dtype=torch.int64,
356-
)
343+
self.slot_mapping_fp8 = get_empty((self.max_num_tokens, ),
344+
cache_name="slot_mapping_fp8",
345+
dtype=torch.int64,
346+
capture_graph=capture_graph)
357347
self.host_slot_mapping_fp8 = torch.zeros_like(
358348
self.slot_mapping_fp8,
359349
device='cpu',
360350
pin_memory=True,
361351
)
362-
self.slot_mapping_scale = get_empty(
363-
(self.max_num_tokens, ),
364-
cache_name="slot_mapping_scale",
365-
dtype=torch.int64,
366-
)
352+
self.slot_mapping_scale = get_empty((self.max_num_tokens, ),
353+
cache_name="slot_mapping_scale",
354+
dtype=torch.int64,
355+
capture_graph=capture_graph)
367356
self.host_slot_mapping_scale = torch.zeros_like(
368357
self.slot_mapping_scale,
369358
device='cpu',
370359
pin_memory=True,
371360
)
372361
# Per-token request index buffer for topk_indices conversion
373-
self.req_idx_per_token = get_empty(
374-
(self.max_num_tokens, ),
375-
cache_name="req_idx_per_token",
376-
dtype=torch.int32,
377-
)
362+
self.req_idx_per_token = get_empty((self.max_num_tokens, ),
363+
cache_name="req_idx_per_token",
364+
dtype=torch.int32,
365+
capture_graph=capture_graph)
378366
# Block table for topk_indices conversion (shared for context and generation)
379367
self.block_table = get_empty(
380368
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
381369
cache_name="block_table",
382370
dtype=torch.int32,
383-
)
371+
capture_graph=capture_graph)
384372
self.scheduler_metadata_buffer = get_empty(
385373
(self.num_sms + 1, 2),
386374
cache_name="scheduler_metadata_buffer",
387375
dtype=torch.int32,
388-
)
389-
self.cu_seqlen_ks = get_empty(
390-
(self.max_num_tokens, ),
391-
cache_name="cu_seqlen_ks",
392-
dtype=torch.int32,
393-
)
394-
self.cu_seqlen_ke = get_empty(
395-
(self.max_num_tokens, ),
396-
cache_name="cu_seqlen_ke",
397-
dtype=torch.int32,
398-
)
376+
capture_graph=capture_graph)
377+
self.cu_seqlen_ks = get_empty((self.max_num_tokens, ),
378+
cache_name="cu_seqlen_ks",
379+
dtype=torch.int32,
380+
capture_graph=capture_graph)
381+
self.cu_seqlen_ke = get_empty((self.max_num_tokens, ),
382+
cache_name="cu_seqlen_ke",
383+
dtype=torch.int32,
384+
capture_graph=capture_graph)
399385

400386
def prepare(self):
401387
super().prepare()

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ def __post_init__(self):
3535
if self.sparse_attention_config is None:
3636
raise ValueError("Sparse attention config is not set")
3737
self.prompt_budget = self.sparse_attention_config.prompt_budget
38-
self.kt_cache_block_offsets = torch.empty(
39-
[
38+
39+
capture_graph = torch.cuda.is_current_stream_capturing()
40+
self.kt_cache_block_offsets = self.get_empty(
41+
self.cuda_graph_buffers, [
4042
self.max_num_sequences,
4143
self.kv_cache_manager.max_kt_blocks_per_seq
4244
],
4345
dtype=torch.int32,
44-
device='cuda',
45-
)
46+
cache_name="kt_cache_block_offsets",
47+
capture_graph=capture_graph)
48+
4649
self.host_kt_cache_block_offsets = torch.zeros_like(
4750
self.kt_cache_block_offsets,
4851
device='cpu',

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 27 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -641,50 +641,20 @@ def _post_init_with_buffers(self, buffers) -> None:
641641

642642
capture_graph = torch.cuda.is_current_stream_capturing()
643643

644-
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
645-
cache_name: str) -> torch.Tensor:
646-
"""
647-
Finds a compatible, reusable buffer from a cache or creates a new one.
648-
649-
This function searches for a pre-allocated tensor (buffer) that can be
650-
reused for an operation involving a tensor with the shape of `tensor_shape`.
651-
652-
The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
653-
654-
If a compatible buffer is found, it's returned immediately. Otherwise, a new
655-
buffer is allocated on the 'cuda' device with the given properties of 'tensor_shape' and 'dtype'.
656-
657-
Args:
658-
tensor_shape: The required shape.
659-
dtype: The required dtype.
660-
cache_name: The key for the specific list of buffers to search in.
661-
Returns:
662-
An existing compatible buffer or a newly created one.
663-
"""
664-
if buffers is None:
665-
return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
666-
667-
return buffers.get_buffer(tensor_shape, dtype, cache_name,
668-
capture_graph)
669-
670-
def get_empty_like(like_tensor: torch.Tensor,
671-
cache_name: str) -> torch.Tensor:
672-
return get_empty(like_tensor.shape,
673-
cache_name=cache_name,
674-
dtype=like_tensor.dtype)
675-
676-
self.prompt_lens_cuda = get_empty(
677-
(self.max_num_sequences, ),
678-
cache_name="prompt_lens_cuda",
679-
dtype=torch.int,
680-
)
644+
self.prompt_lens_cuda = self.get_empty(buffers,
645+
(self.max_num_sequences, ),
646+
cache_name="prompt_lens_cuda",
647+
dtype=torch.int,
648+
capture_graph=capture_graph)
681649
self.prompt_lens_cpu = torch.empty_like(
682650
self.prompt_lens_cuda,
683651
device='cpu',
684652
pin_memory=True,
685653
)
686-
self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
687-
cache_name="kv_lens_cuda")
654+
self.kv_lens_cuda = self.get_empty_like(buffers,
655+
self.prompt_lens_cuda,
656+
cache_name="kv_lens_cuda",
657+
capture_graph=capture_graph)
688658
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
689659
device='cpu',
690660
pin_memory=True)
@@ -699,14 +669,14 @@ def get_empty_like(like_tensor: torch.Tensor,
699669
dtype=torch.int8,
700670
)
701671
if self.kv_cache_manager is not None:
702-
self.kv_cache_block_offsets = get_empty(
703-
[
672+
self.kv_cache_block_offsets = self.get_empty(
673+
buffers, [
704674
self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
705675
self.kv_cache_manager.max_blocks_per_seq
706676
],
707677
cache_name="kv_cache_block_offsets",
708678
dtype=torch.int32,
709-
)
679+
capture_graph=capture_graph)
710680
self.host_kv_cache_block_offsets = torch.empty_like(
711681
self.kv_cache_block_offsets,
712682
device='cpu',
@@ -715,50 +685,50 @@ def get_empty_like(like_tensor: torch.Tensor,
715685
self.block_ids_per_seq = None
716686
self.kv_block_ids_per_seq = None
717687
if self.enable_flash_mla:
718-
self.block_ids_per_seq = get_empty(
719-
[
688+
self.block_ids_per_seq = self.get_empty(
689+
buffers, [
720690
self.kv_cache_manager.max_batch_size,
721691
self.kv_cache_manager.max_blocks_per_seq
722692
],
723693
cache_name="block_ids_per_seq",
724694
dtype=torch.int32,
725-
)
726-
self.kv_block_ids_per_seq = get_empty(
727-
[
695+
capture_graph=capture_graph)
696+
self.kv_block_ids_per_seq = self.get_empty(
697+
buffers, [
728698
self.kv_cache_manager.max_batch_size,
729699
self.kv_cache_manager.max_blocks_per_seq
730700
],
731701
cache_name="kv_block_ids_per_seq",
732702
dtype=torch.int32,
733-
)
703+
capture_graph=capture_graph)
734704
if self.enable_context_mla_with_cached_kv:
735705
# for kv cache reuse/chunked context in MLA
736-
self.ctx_cached_token_indptr = get_empty(
737-
(self.max_num_requests + 1, ),
706+
self.ctx_cached_token_indptr = self.get_empty(
707+
buffers, (self.max_num_requests + 1, ),
738708
cache_name="ctx_cached_token_indptr",
739709
dtype=torch.int64,
740-
)
710+
capture_graph=capture_graph)
741711
self.host_ctx_cached_token_indptr = torch.zeros_like(
742712
self.ctx_cached_token_indptr,
743713
device='cpu',
744714
pin_memory=True,
745715
)
746-
self.ctx_uncached_token_indptr = get_empty(
747-
(self.max_num_requests + 1, ),
716+
self.ctx_uncached_token_indptr = self.get_empty(
717+
buffers, (self.max_num_requests + 1, ),
748718
cache_name="ctx_uncached_token_indptr",
749719
dtype=torch.int64,
750-
)
720+
capture_graph=capture_graph)
751721
self.host_ctx_uncached_token_indptr = torch.zeros_like(
752722
self.ctx_uncached_token_indptr,
753723
device='cpu',
754724
pin_memory=True,
755725
)
756726
# context full seqlens include cached tokens and uncached tokens
757-
self.ctx_kv_indptr = get_empty(
758-
(self.max_num_requests + 1, ),
727+
self.ctx_kv_indptr = self.get_empty(
728+
buffers, (self.max_num_requests + 1, ),
759729
cache_name="ctx_kv_indptr",
760730
dtype=torch.int64,
761-
)
731+
capture_graph=capture_graph)
762732
self.host_ctx_kv_indptr = torch.zeros_like(
763733
self.ctx_kv_indptr,
764734
device='cpu',

0 commit comments

Comments
 (0)