Skip to content

Commit e22d69c

Browse files
committed
Fix error from rebase
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 8fc566d commit e22d69c

File tree

1 file changed

+18
-10
lines changed
  • tensorrt_llm/_torch/attention_backend/sparse

1 file changed

+18
-10
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def __post_init__(self):
287287
capture_graph = torch.cuda.is_current_stream_capturing()
288288

289289
self.indexer_k_cache_block_offsets = self.get_empty(
290+
self.cuda_graph_buffers,
290291
[self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
291292
cache_name="indexer_k_cache_block_offsets",
292293
dtype=torch.int32,
@@ -300,7 +301,7 @@ def __post_init__(self):
300301
# For mla_rope_append_paged_kv_assign_q
301302
if not self.enable_context_mla_with_cached_kv:
302303
self.ctx_cached_token_indptr = self.get_empty(
303-
(self.max_num_requests + 1, ),
304+
self.cuda_graph_buffers, (self.max_num_requests + 1, ),
304305
cache_name="ctx_cached_token_indptr",
305306
dtype=torch.int64,
306307
capture_graph=capture_graph)
@@ -309,7 +310,8 @@ def __post_init__(self):
309310
device='cpu',
310311
pin_memory=True,
311312
)
312-
self.ctx_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
313+
self.ctx_kv_indptr = self.get_empty(self.cuda_graph_buffers,
314+
(self.max_num_requests + 1, ),
313315
cache_name="ctx_kv_indptr",
314316
dtype=torch.int64,
315317
capture_graph=capture_graph)
@@ -320,7 +322,7 @@ def __post_init__(self):
320322
)
321323
# New generation buffers for dsa
322324
self.gen_cached_token_indptr = self.get_empty(
323-
(self.max_num_requests + 1, ),
325+
self.cuda_graph_buffers, (self.max_num_requests + 1, ),
324326
cache_name="gen_cached_token_indptr",
325327
dtype=torch.int64,
326328
capture_graph=capture_graph)
@@ -329,7 +331,8 @@ def __post_init__(self):
329331
device='cpu',
330332
pin_memory=True,
331333
)
332-
self.gen_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
334+
self.gen_kv_indptr = self.get_empty(self.cuda_graph_buffers,
335+
(self.max_num_requests + 1, ),
333336
cache_name="gen_kv_indptr",
334337
dtype=torch.int64,
335338
capture_graph=capture_graph)
@@ -340,7 +343,8 @@ def __post_init__(self):
340343
)
341344
# Indexer metadata
342345
# Separate slot mappings for non-interleaved layout (flat byte indices)
343-
self.slot_mapping_fp8 = self.get_empty((self.max_num_tokens, ),
346+
self.slot_mapping_fp8 = self.get_empty(self.cuda_graph_buffers,
347+
(self.max_num_tokens, ),
344348
cache_name="slot_mapping_fp8",
345349
dtype=torch.int64,
346350
capture_graph=capture_graph)
@@ -350,7 +354,7 @@ def __post_init__(self):
350354
pin_memory=True,
351355
)
352356
self.slot_mapping_scale = self.get_empty(
353-
(self.max_num_tokens, ),
357+
self.cuda_graph_buffers, (self.max_num_tokens, ),
354358
cache_name="slot_mapping_scale",
355359
dtype=torch.int64,
356360
capture_graph=capture_graph)
@@ -360,26 +364,30 @@ def __post_init__(self):
360364
pin_memory=True,
361365
)
362366
# Per-token request index buffer for topk_indices conversion
363-
self.req_idx_per_token = self.get_empty((self.max_num_tokens, ),
367+
self.req_idx_per_token = self.get_empty(self.cuda_graph_buffers,
368+
(self.max_num_tokens, ),
364369
cache_name="req_idx_per_token",
365370
dtype=torch.int32,
366371
capture_graph=capture_graph)
367372
# Block table for topk_indices conversion (shared for context and generation)
368373
self.block_table = self.get_empty(
374+
self.cuda_graph_buffers,
369375
(self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
370376
cache_name="block_table",
371377
dtype=torch.int32,
372378
capture_graph=capture_graph)
373379
self.scheduler_metadata_buffer = self.get_empty(
374-
(self.num_sms + 1, 2),
380+
self.cuda_graph_buffers, (self.num_sms + 1, 2),
375381
cache_name="scheduler_metadata_buffer",
376382
dtype=torch.int32,
377383
capture_graph=capture_graph)
378-
self.cu_seqlen_ks = self.get_empty((self.max_num_tokens, ),
384+
self.cu_seqlen_ks = self.get_empty(self.cuda_graph_buffers,
385+
(self.max_num_tokens, ),
379386
cache_name="cu_seqlen_ks",
380387
dtype=torch.int32,
381388
capture_graph=capture_graph)
382-
self.cu_seqlen_ke = self.get_empty((self.max_num_tokens, ),
389+
self.cu_seqlen_ke = self.get_empty(self.cuda_graph_buffers,
390+
(self.max_num_tokens, ),
383391
cache_name="cu_seqlen_ke",
384392
dtype=torch.int32,
385393
capture_graph=capture_graph)

0 commit comments

Comments (0)