
Commit b9ae472

Address comment
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 48a3f28 commit b9ae472

2 files changed: +62 -100 lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 62 additions & 40 deletions
@@ -309,7 +309,8 @@ def __post_init__(self):
             [self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
             cache_name="indexer_k_cache_block_offsets",
             dtype=torch.int32,
-            capture_graph=capture_graph)
+            capture_graph=capture_graph,
+        )
         self.host_indexer_k_cache_block_offsets = torch.zeros_like(
             self.indexer_k_cache_block_offsets,
             device='cpu',
@@ -319,96 +320,117 @@ def __post_init__(self):
         # For mla_rope_append_paged_kv_assign_q
         if not self.enable_context_mla_with_cached_kv:
             self.ctx_cached_token_indptr = self.get_empty(
-                self.cuda_graph_buffers, (self.max_num_requests + 1, ),
+                self.cuda_graph_buffers,
+                (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
-                capture_graph=capture_graph)
+                capture_graph=capture_graph,
+            )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_kv_indptr = self.get_empty(self.cuda_graph_buffers,
-                                                (self.max_num_requests + 1, ),
-                                                cache_name="ctx_kv_indptr",
-                                                dtype=torch.int64,
-                                                capture_graph=capture_graph)
+            self.ctx_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
+                (self.max_num_requests + 1, ),
+                cache_name="ctx_kv_indptr",
+                dtype=torch.int64,
+                capture_graph=capture_graph,
+            )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # New generation buffers for dsa
             self.gen_cached_token_indptr = self.get_empty(
-                self.cuda_graph_buffers, (self.max_num_requests + 1, ),
+                self.cuda_graph_buffers,
+                (self.max_num_requests + 1, ),
                 cache_name="gen_cached_token_indptr",
                 dtype=torch.int64,
-                capture_graph=capture_graph)
+                capture_graph=capture_graph,
+            )
             self.host_gen_cached_token_indptr = torch.zeros_like(
                 self.gen_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.gen_kv_indptr = self.get_empty(self.cuda_graph_buffers,
-                                                (self.max_num_requests + 1, ),
-                                                cache_name="gen_kv_indptr",
-                                                dtype=torch.int64,
-                                                capture_graph=capture_graph)
+            self.gen_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
+                (self.max_num_requests + 1, ),
+                cache_name="gen_kv_indptr",
+                dtype=torch.int64,
+                capture_graph=capture_graph,
+            )
             self.host_gen_kv_indptr = torch.zeros_like(
                 self.gen_kv_indptr,
                 device='cpu',
                 pin_memory=True,
             )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)
-        self.slot_mapping_fp8 = self.get_empty(self.cuda_graph_buffers,
-                                               (self.max_num_tokens, ),
-                                               cache_name="slot_mapping_fp8",
-                                               dtype=torch.int64,
-                                               capture_graph=capture_graph)
+        self.slot_mapping_fp8 = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_tokens, ),
+            cache_name="slot_mapping_fp8",
+            dtype=torch.int64,
+            capture_graph=capture_graph,
+        )
         self.host_slot_mapping_fp8 = torch.zeros_like(
             self.slot_mapping_fp8,
             device='cpu',
             pin_memory=True,
         )
         self.slot_mapping_scale = self.get_empty(
-            self.cuda_graph_buffers, (self.max_num_tokens, ),
+            self.cuda_graph_buffers,
+            (self.max_num_tokens, ),
             cache_name="slot_mapping_scale",
             dtype=torch.int64,
-            capture_graph=capture_graph)
+            capture_graph=capture_graph,
+        )
         self.host_slot_mapping_scale = torch.zeros_like(
             self.slot_mapping_scale,
             device='cpu',
             pin_memory=True,
         )
         # Per-token request index buffer for topk_indices conversion
-        self.req_idx_per_token = self.get_empty(self.cuda_graph_buffers,
-                                                (self.max_num_tokens, ),
-                                                cache_name="req_idx_per_token",
-                                                dtype=torch.int32,
-                                                capture_graph=capture_graph)
+        self.req_idx_per_token = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_tokens, ),
+            cache_name="req_idx_per_token",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
         # Block table for topk_indices conversion (shared for context and generation)
         self.block_table = self.get_empty(
             self.cuda_graph_buffers,
             (self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
             cache_name="block_table",
             dtype=torch.int32,
-            capture_graph=capture_graph)
+            capture_graph=capture_graph,
+        )
         self.scheduler_metadata_buffer = self.get_empty(
-            self.cuda_graph_buffers, (self.num_sms + 1, 2),
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
             cache_name="scheduler_metadata_buffer",
             dtype=torch.int32,
-            capture_graph=capture_graph)
-        self.cu_seqlen_ks = self.get_empty(self.cuda_graph_buffers,
-                                           (self.max_num_tokens, ),
-                                           cache_name="cu_seqlen_ks",
-                                           dtype=torch.int32,
-                                           capture_graph=capture_graph)
-        self.cu_seqlen_ke = self.get_empty(self.cuda_graph_buffers,
-                                           (self.max_num_tokens, ),
-                                           cache_name="cu_seqlen_ke",
-                                           dtype=torch.int32,
-                                           capture_graph=capture_graph)
+            capture_graph=capture_graph,
+        )
+        self.cu_seqlen_ks = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_tokens, ),
+            cache_name="cu_seqlen_ks",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.cu_seqlen_ke = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_tokens, ),
+            cache_name="cu_seqlen_ke",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
 
     def prepare(self):
         super().prepare()
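Both hunks in this file are purely stylistic: each multi-line get_empty call is reflowed to one argument per line with a trailing comma, with no behavioral change. Note the recurring pattern the hunks touch: every device-side metadata buffer is paired with a pinned host mirror (torch.zeros_like(..., device='cpu', pin_memory=True)), so prepare() can stage values on the CPU and copy them asynchronously into device buffers whose addresses stay stable across CUDA graph replays. A minimal runnable sketch of that pattern follows; the get_empty below is a simplified stand-in for the class helper, not the TensorRT-LLM implementation:

import torch

_buffer_cache: dict[str, torch.Tensor] = {}

def get_empty(shape: tuple, cache_name: str, dtype: torch.dtype) -> torch.Tensor:
    # Simplified stand-in: cache buffers by name so repeated calls (and CUDA
    # graph replays) always see the same device address.
    buf = _buffer_cache.get(cache_name)
    if buf is None:
        buf = torch.zeros(shape, dtype=dtype, device='cuda')
        _buffer_cache[cache_name] = buf
    return buf

kv_indptr = get_empty((9, ), cache_name="ctx_kv_indptr", dtype=torch.int64)
# Pinned host mirror: stage writes on the CPU, then issue one async H2D copy.
host_kv_indptr = torch.zeros_like(kv_indptr, device='cpu', pin_memory=True)
host_kv_indptr[1:] = torch.arange(1, 9).cumsum(0)  # example indptr contents
kv_indptr.copy_(host_kv_indptr, non_blocking=True)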

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 0 additions & 60 deletions
@@ -1,6 +1,5 @@
 import contextlib
 import math
-from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Optional
 

@@ -9,28 +8,6 @@
 from tensorrt_llm.logger import logger
 
 
-def get_smallest_key_greater_than(ordered_dict, target_value):
-    """
-    Return (k, ordered_dict[k]) where k is the smallest key with k >= target_value,
-    or (None, None) if not found.
-    """
-    min_key = min((k for k in ordered_dict.keys() if k >= target_value),
-                  default=None)
-    return (min_key, ordered_dict[min_key]) if min_key is not None else (None,
-                                                                         None)
-
-
-def get_biggest_key_smaller_than(ordered_dict, target_value):
-    """
-    Return (k, ordered_dict[k]) where k is the largest key with k < target_value,
-    or (None, None) if not found.
-    """
-    max_key = max((k for k in ordered_dict.keys() if k < target_value),
-                  default=None)
-    return (max_key, ordered_dict[max_key]) if max_key is not None else (None,
-                                                                         None)
-
-
 def get_size_in_byte(target_shape: list[int], target_dtype: torch.dtype):
     return math.prod(target_shape) * target_dtype.itemsize
 
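For context on this deletion: the two removed helpers implemented ceiling and floor lookups over an OrderedDict keyed by buffer size in bytes, and were used only by the managed-buffer path deleted further down. A standalone sketch of the same behavior (the demo pool and sizes are invented for illustration):

from collections import OrderedDict

def get_smallest_key_greater_than(ordered_dict, target_value):
    # Ceiling lookup: smallest key k with k >= target_value.
    min_key = min((k for k in ordered_dict if k >= target_value), default=None)
    return (min_key, ordered_dict[min_key]) if min_key is not None else (None, None)

def get_biggest_key_smaller_than(ordered_dict, target_value):
    # Floor lookup: largest key k with k < target_value.
    max_key = max((k for k in ordered_dict if k < target_value), default=None)
    return (max_key, ordered_dict[max_key]) if max_key is not None else (None, None)

pool = OrderedDict({256: "buf_a", 1024: "buf_b"})  # size in bytes -> buffer
print(get_smallest_key_greater_than(pool, 512))    # (1024, 'buf_b')
print(get_biggest_key_smaller_than(pool, 512))     # (256, 'buf_a')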

@@ -57,7 +34,6 @@ class Buffers:
 
     def __init__(self):
         self.buffers: dict[str, list[BufferBlock]] = {}
-        self.managed_buffers = OrderedDict()
         self.max_buffer_concurrency = 0
 
     @staticmethod
@@ -74,48 +50,12 @@ def _view_as(buffer: torch.Tensor, target_shape: list[int],
         return buffer[:required_memory_size].view(target_dtype).view(
             target_shape)
 
-    def _get_managed_buffer(self, required_memory_size: int):
-        size, buffer = get_smallest_key_greater_than(self.managed_buffers,
-                                                     required_memory_size)
-
-        if size is not None and buffer is not None:
-            return buffer
-
-        size_1, buffer_1 = get_biggest_key_smaller_than(self.managed_buffers,
-                                                        required_memory_size)
-        if size_1 is not None and buffer is not None:
-            del self.managed_buffers[size_1]
-
-        new_buffer_tensor = None
-        try:
-            with torch.cuda.memory.use_mem_pool(get_shared_pool()):
-                new_buffer_tensor = torch.zeros((required_memory_size, ),
-                                                device='cuda',
-                                                dtype=torch.uint8)
-        except Exception as ex:
-            # Need to check if this is an OOM exception
-            logger.debug(
-                f"Exception happened to create tensor from given memory pool: {str(ex)}"
-            )
-            # if exception happens during allocating memory from shared pool, retry
-            # to allocate from default pool
-            new_buffer_tensor = torch.zeros((required_memory_size, ),
-                                            device='cuda',
-                                            dtype=torch.uint8)
-
-        self.managed_buffers[required_memory_size] = new_buffer_tensor
-
-        return new_buffer_tensor
-
     def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
                    buffer_name: str, reserve_buffer: bool):
 
         # all buffers are allocated with 1 byte element size
         required_memory_size = math.prod(tensor_shape) * dtype.itemsize
 
-        if buffer_name is None or len(buffer_name) == 0:
-            return _get_managed_buffer(required_memory_size)
-
         candidate_blocks = self.buffers.get(buffer_name, [])
 
         # Find the best-fit available buffer.
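With the managed-buffer fallback removed, every request now goes through the named best-fit path: get_buffer computes the request's size in bytes and _view_as reinterprets a prefix of a flat uint8 block as the requested dtype and shape. (The deleted path also carried latent bugs visible above: get_buffer called _get_managed_buffer without self., and the floor-lookup branch tested buffer rather than buffer_1.) A minimal sketch of the reinterpretation step, standalone rather than the library's implementation:

import math

import torch

def view_as(buffer: torch.Tensor, shape: list[int], dtype: torch.dtype) -> torch.Tensor:
    # Pooled blocks are flat uint8; serve a request by viewing the first
    # `required` bytes as the target dtype, then as the target shape.
    required = math.prod(shape) * dtype.itemsize
    assert buffer.dtype == torch.uint8 and buffer.numel() >= required
    return buffer[:required].view(dtype).view(shape)

block = torch.zeros((4096, ), dtype=torch.uint8)  # one pooled byte block
a = view_as(block, [8, 16], torch.float16)        # uses 256 of 4096 bytes
b = view_as(block, [64], torch.int32)             # same storage, different view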
