Skip to content

Commit 8f7d535

Browse files
committed
Address comments
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 2191b8f commit 8f7d535

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def __post_init__(self):
286286

287287
capture_graph = torch.cuda.is_current_stream_capturing()
288288

289-
self.indexer_k_cache_block_offsets = get_empty(
289+
self.indexer_k_cache_block_offsets = self.get_empty(
290290
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
291291
cache_name="indexer_k_cache_block_offsets",
292292
dtype=torch.int32,
@@ -299,7 +299,7 @@ def __post_init__(self):
299299

300300
# For mla_rope_append_paged_kv_assign_q
301301
if not self.enable_context_mla_with_cached_kv:
302-
self.ctx_cached_token_indptr = get_empty(
302+
self.ctx_cached_token_indptr = self.get_empty(
303303
(self.max_num_requests + 1, ),
304304
cache_name="ctx_cached_token_indptr",
305305
dtype=torch.int64,
@@ -309,17 +309,17 @@ def __post_init__(self):
309309
device='cpu',
310310
pin_memory=True,
311311
)
312-
self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
313-
cache_name="ctx_kv_indptr",
314-
dtype=torch.int64,
315-
capture_graph=capture_graph)
312+
self.ctx_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
313+
cache_name="ctx_kv_indptr",
314+
dtype=torch.int64,
315+
capture_graph=capture_graph)
316316
self.host_ctx_kv_indptr = torch.zeros_like(
317317
self.ctx_kv_indptr,
318318
device='cpu',
319319
pin_memory=True,
320320
)
321321
# New generation buffers for dsa
322-
self.gen_cached_token_indptr = get_empty(
322+
self.gen_cached_token_indptr = self.get_empty(
323323
(self.max_num_requests + 1, ),
324324
cache_name="gen_cached_token_indptr",
325325
dtype=torch.int64,
@@ -329,59 +329,60 @@ def __post_init__(self):
329329
device='cpu',
330330
pin_memory=True,
331331
)
332-
self.gen_kv_indptr = get_empty((self.max_num_requests + 1, ),
333-
cache_name="gen_kv_indptr",
334-
dtype=torch.int64,
335-
capture_graph=capture_graph)
332+
self.gen_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
333+
cache_name="gen_kv_indptr",
334+
dtype=torch.int64,
335+
capture_graph=capture_graph)
336336
self.host_gen_kv_indptr = torch.zeros_like(
337337
self.gen_kv_indptr,
338338
device='cpu',
339339
pin_memory=True,
340340
)
341341
# Indexer metadata
342342
# Separate slot mappings for non-interleaved layout (flat byte indices)
343-
self.slot_mapping_fp8 = get_empty((self.max_num_tokens, ),
344-
cache_name="slot_mapping_fp8",
345-
dtype=torch.int64,
346-
capture_graph=capture_graph)
343+
self.slot_mapping_fp8 = self.get_empty((self.max_num_tokens, ),
344+
cache_name="slot_mapping_fp8",
345+
dtype=torch.int64,
346+
capture_graph=capture_graph)
347347
self.host_slot_mapping_fp8 = torch.zeros_like(
348348
self.slot_mapping_fp8,
349349
device='cpu',
350350
pin_memory=True,
351351
)
352-
self.slot_mapping_scale = get_empty((self.max_num_tokens, ),
353-
cache_name="slot_mapping_scale",
354-
dtype=torch.int64,
355-
capture_graph=capture_graph)
352+
self.slot_mapping_scale = self.get_empty(
353+
(self.max_num_tokens, ),
354+
cache_name="slot_mapping_scale",
355+
dtype=torch.int64,
356+
capture_graph=capture_graph)
356357
self.host_slot_mapping_scale = torch.zeros_like(
357358
self.slot_mapping_scale,
358359
device='cpu',
359360
pin_memory=True,
360361
)
361362
# Per-token request index buffer for topk_indices conversion
362-
self.req_idx_per_token = get_empty((self.max_num_tokens, ),
363-
cache_name="req_idx_per_token",
364-
dtype=torch.int32,
365-
capture_graph=capture_graph)
363+
self.req_idx_per_token = self.get_empty((self.max_num_tokens, ),
364+
cache_name="req_idx_per_token",
365+
dtype=torch.int32,
366+
capture_graph=capture_graph)
366367
# Block table for topk_indices conversion (shared for context and generation)
367-
self.block_table = get_empty(
368+
self.block_table = self.get_empty(
368369
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
369370
cache_name="block_table",
370371
dtype=torch.int32,
371372
capture_graph=capture_graph)
372-
self.scheduler_metadata_buffer = get_empty(
373+
self.scheduler_metadata_buffer = self.get_empty(
373374
(self.num_sms + 1, 2),
374375
cache_name="scheduler_metadata_buffer",
375376
dtype=torch.int32,
376377
capture_graph=capture_graph)
377-
self.cu_seqlen_ks = get_empty((self.max_num_tokens, ),
378-
cache_name="cu_seqlen_ks",
379-
dtype=torch.int32,
380-
capture_graph=capture_graph)
381-
self.cu_seqlen_ke = get_empty((self.max_num_tokens, ),
382-
cache_name="cu_seqlen_ke",
383-
dtype=torch.int32,
384-
capture_graph=capture_graph)
378+
self.cu_seqlen_ks = self.get_empty((self.max_num_tokens, ),
379+
cache_name="cu_seqlen_ks",
380+
dtype=torch.int32,
381+
capture_graph=capture_graph)
382+
self.cu_seqlen_ke = self.get_empty((self.max_num_tokens, ),
383+
cache_name="cu_seqlen_ke",
384+
dtype=torch.int32,
385+
capture_graph=capture_graph)
385386

386387
def prepare(self):
387388
super().prepare()

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@
99
from tensorrt_llm.logger import logger
1010

1111

12+
def get_smallest_key_greater_than(ordered_dict, target_value):
    """Find the smallest key that is greater than or equal to *target_value*.

    NOTE(review): despite the name, the comparison is inclusive
    (``k >= target_value``) — this matches the original behavior, so
    callers may rely on an exact match being returned.

    Args:
        ordered_dict: A mapping with mutually comparable keys. No key
            ordering is required; all keys are scanned (O(n)).
        target_value: Lower bound to compare keys against.

    Returns:
        ``(key, ordered_dict[key])`` for the smallest qualifying key,
        or ``(None, None)`` if no key satisfies the bound.
    """
    # Iterate the mapping directly; calling .keys() is redundant.
    min_key = min((k for k in ordered_dict if k >= target_value),
                  default=None)
    if min_key is None:
        return None, None
    return min_key, ordered_dict[min_key]
21+
22+
23+
def get_biggest_key_smaller_than(ordered_dict, target_value):
    """Find the largest key that is strictly less than *target_value*.

    NOTE(review): the comparison here is strict (``k < target_value``),
    while the sibling ``get_smallest_key_greater_than`` is inclusive —
    this asymmetry is preserved from the original behavior.

    Args:
        ordered_dict: A mapping with mutually comparable keys. No key
            ordering is required; all keys are scanned (O(n)).
        target_value: Exclusive upper bound to compare keys against.

    Returns:
        ``(key, ordered_dict[key])`` for the largest qualifying key,
        or ``(None, None)`` if no key satisfies the bound.
    """
    # Iterate the mapping directly; calling .keys() is redundant.
    max_key = max((k for k in ordered_dict if k < target_value),
                  default=None)
    if max_key is None:
        return None, None
    return max_key, ordered_dict[max_key]
32+
33+
1234
def get_size_in_byte(target_shape: list[int], target_dtype: torch.dtype):
1335
return math.prod(target_shape) * target_dtype.itemsize
1436

0 commit comments

Comments
 (0)