Skip to content

Commit 5686e23

Browse files
committed
Address comments
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 273d19f commit 5686e23

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def __post_init__(self):
304304

305305
capture_graph = torch.cuda.is_current_stream_capturing()
306306

307-
self.indexer_k_cache_block_offsets = get_empty(
307+
self.indexer_k_cache_block_offsets = self.get_empty(
308308
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
309309
cache_name="indexer_k_cache_block_offsets",
310310
dtype=torch.int32,
@@ -317,7 +317,7 @@ def __post_init__(self):
317317

318318
# For mla_rope_append_paged_kv_assign_q
319319
if not self.enable_context_mla_with_cached_kv:
320-
self.ctx_cached_token_indptr = get_empty(
320+
self.ctx_cached_token_indptr = self.get_empty(
321321
(self.max_num_requests + 1, ),
322322
cache_name="ctx_cached_token_indptr",
323323
dtype=torch.int64,
@@ -327,17 +327,17 @@ def __post_init__(self):
327327
device='cpu',
328328
pin_memory=True,
329329
)
330-
self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
331-
cache_name="ctx_kv_indptr",
332-
dtype=torch.int64,
333-
capture_graph=capture_graph)
330+
self.ctx_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
331+
cache_name="ctx_kv_indptr",
332+
dtype=torch.int64,
333+
capture_graph=capture_graph)
334334
self.host_ctx_kv_indptr = torch.zeros_like(
335335
self.ctx_kv_indptr,
336336
device='cpu',
337337
pin_memory=True,
338338
)
339339
# New generation buffers for dsa
340-
self.gen_cached_token_indptr = get_empty(
340+
self.gen_cached_token_indptr = self.get_empty(
341341
(self.max_num_requests + 1, ),
342342
cache_name="gen_cached_token_indptr",
343343
dtype=torch.int64,
@@ -347,59 +347,60 @@ def __post_init__(self):
347347
device='cpu',
348348
pin_memory=True,
349349
)
350-
self.gen_kv_indptr = get_empty((self.max_num_requests + 1, ),
351-
cache_name="gen_kv_indptr",
352-
dtype=torch.int64,
353-
capture_graph=capture_graph)
350+
self.gen_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
351+
cache_name="gen_kv_indptr",
352+
dtype=torch.int64,
353+
capture_graph=capture_graph)
354354
self.host_gen_kv_indptr = torch.zeros_like(
355355
self.gen_kv_indptr,
356356
device='cpu',
357357
pin_memory=True,
358358
)
359359
# Indexer metadata
360360
# Separate slot mappings for non-interleaved layout (flat byte indices)
361-
self.slot_mapping_fp8 = get_empty((self.max_num_tokens, ),
362-
cache_name="slot_mapping_fp8",
363-
dtype=torch.int64,
364-
capture_graph=capture_graph)
361+
self.slot_mapping_fp8 = self.get_empty((self.max_num_tokens, ),
362+
cache_name="slot_mapping_fp8",
363+
dtype=torch.int64,
364+
capture_graph=capture_graph)
365365
self.host_slot_mapping_fp8 = torch.zeros_like(
366366
self.slot_mapping_fp8,
367367
device='cpu',
368368
pin_memory=True,
369369
)
370-
self.slot_mapping_scale = get_empty((self.max_num_tokens, ),
371-
cache_name="slot_mapping_scale",
372-
dtype=torch.int64,
373-
capture_graph=capture_graph)
370+
self.slot_mapping_scale = self.get_empty(
371+
(self.max_num_tokens, ),
372+
cache_name="slot_mapping_scale",
373+
dtype=torch.int64,
374+
capture_graph=capture_graph)
374375
self.host_slot_mapping_scale = torch.zeros_like(
375376
self.slot_mapping_scale,
376377
device='cpu',
377378
pin_memory=True,
378379
)
379380
# Per-token request index buffer for topk_indices conversion
380-
self.req_idx_per_token = get_empty((self.max_num_tokens, ),
381-
cache_name="req_idx_per_token",
382-
dtype=torch.int32,
383-
capture_graph=capture_graph)
381+
self.req_idx_per_token = self.get_empty((self.max_num_tokens, ),
382+
cache_name="req_idx_per_token",
383+
dtype=torch.int32,
384+
capture_graph=capture_graph)
384385
# Block table for topk_indices conversion (shared for context and generation)
385-
self.block_table = get_empty(
386+
self.block_table = self.get_empty(
386387
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
387388
cache_name="block_table",
388389
dtype=torch.int32,
389390
capture_graph=capture_graph)
390-
self.scheduler_metadata_buffer = get_empty(
391+
self.scheduler_metadata_buffer = self.get_empty(
391392
(self.num_sms + 1, 2),
392393
cache_name="scheduler_metadata_buffer",
393394
dtype=torch.int32,
394395
capture_graph=capture_graph)
395-
self.cu_seqlen_ks = get_empty((self.max_num_tokens, ),
396-
cache_name="cu_seqlen_ks",
397-
dtype=torch.int32,
398-
capture_graph=capture_graph)
399-
self.cu_seqlen_ke = get_empty((self.max_num_tokens, ),
400-
cache_name="cu_seqlen_ke",
401-
dtype=torch.int32,
402-
capture_graph=capture_graph)
396+
self.cu_seqlen_ks = self.get_empty((self.max_num_tokens, ),
397+
cache_name="cu_seqlen_ks",
398+
dtype=torch.int32,
399+
capture_graph=capture_graph)
400+
self.cu_seqlen_ke = self.get_empty((self.max_num_tokens, ),
401+
cache_name="cu_seqlen_ke",
402+
dtype=torch.int32,
403+
capture_graph=capture_graph)
403404

404405
def prepare(self):
405406
super().prepare()

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@
99
from tensorrt_llm.logger import logger
1010

1111

12+
def get_smallest_key_greater_than(ordered_dict, target_value):
    """
    Return (k, ordered_dict[k]) where k is the smallest key with k >= target_value,
    or (None, None) if not found.

    NOTE(review): despite the "greater_than" name, the comparison is inclusive
    (k >= target_value) — callers rely on the ceiling-lookup behavior.
    """
    # Collect every key at or above the target, then take the smallest.
    eligible_keys = [key for key in ordered_dict if key >= target_value]
    if not eligible_keys:
        return (None, None)
    ceiling_key = min(eligible_keys)
    return (ceiling_key, ordered_dict[ceiling_key])
21+
22+
23+
def get_biggest_key_smaller_than(ordered_dict, target_value):
    """
    Return (k, ordered_dict[k]) where k is the largest key with k < target_value,
    or (None, None) if not found.

    The comparison is strict (k < target_value), matching the function name.
    """
    # Collect every key strictly below the target, then take the largest.
    eligible_keys = [key for key in ordered_dict if key < target_value]
    if not eligible_keys:
        return (None, None)
    floor_key = max(eligible_keys)
    return (floor_key, ordered_dict[floor_key])
32+
33+
1234
def get_size_in_byte(target_shape: list[int], target_dtype: torch.dtype):
1335
return math.prod(target_shape) * target_dtype.itemsize
1436

0 commit comments

Comments
 (0)