Commit 48a3f28

Fix error from rebase

Signed-off-by: Hui Gao <huig@nvidia.com>

1 parent 5686e23

File tree

1 file changed (+18 -10 lines)

  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 18 additions & 10 deletions
In short, the fix passes self.cuda_graph_buffers as a new leading positional argument to every get_empty call in __post_init__, bringing these call sites back in line with the helper's updated signature after the rebase.
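For illustration only (an assumed reconstruction, not part of the commit): a call site left stale by the rebase would mis-bind its arguments against the new signature and fail while the metadata object is being initialized.

    # Hypothetical stale call site after the rebase (the failure mode is
    # inferred from the new call pattern, not taken from the commit):
    self.ctx_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
                                        cache_name="ctx_kv_indptr",
                                        dtype=torch.int64,
                                        capture_graph=capture_graph)
    # With a signature like get_empty(buffers, shape, *, cache_name, ...),
    # the shape tuple binds to the buffers parameter and the real shape
    # argument goes unfilled, so Python raises a TypeError in __post_init__.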
@@ -305,6 +305,7 @@ def __post_init__(self):
         capture_graph = torch.cuda.is_current_stream_capturing()
 
         self.indexer_k_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
             cache_name="indexer_k_cache_block_offsets",
             dtype=torch.int32,
@@ -318,7 +319,7 @@ def __post_init__(self):
         # For mla_rope_append_paged_kv_assign_q
         if not self.enable_context_mla_with_cached_kv:
             self.ctx_cached_token_indptr = self.get_empty(
-                (self.max_num_requests + 1, ),
+                self.cuda_graph_buffers, (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
                 capture_graph=capture_graph)
@@ -327,7 +328,8 @@ def __post_init__(self):
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
+            self.ctx_kv_indptr = self.get_empty(self.cuda_graph_buffers,
+                                                (self.max_num_requests + 1, ),
                                                 cache_name="ctx_kv_indptr",
                                                 dtype=torch.int64,
                                                 capture_graph=capture_graph)
@@ -338,7 +340,7 @@ def __post_init__(self):
             )
             # New generation buffers for dsa
             self.gen_cached_token_indptr = self.get_empty(
-                (self.max_num_requests + 1, ),
+                self.cuda_graph_buffers, (self.max_num_requests + 1, ),
                 cache_name="gen_cached_token_indptr",
                 dtype=torch.int64,
                 capture_graph=capture_graph)
@@ -347,7 +349,8 @@ def __post_init__(self):
                 device='cpu',
                 pin_memory=True,
             )
-            self.gen_kv_indptr = self.get_empty((self.max_num_requests + 1, ),
+            self.gen_kv_indptr = self.get_empty(self.cuda_graph_buffers,
+                                                (self.max_num_requests + 1, ),
                                                 cache_name="gen_kv_indptr",
                                                 dtype=torch.int64,
                                                 capture_graph=capture_graph)
@@ -358,7 +361,8 @@ def __post_init__(self):
             )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)
-        self.slot_mapping_fp8 = self.get_empty((self.max_num_tokens, ),
+        self.slot_mapping_fp8 = self.get_empty(self.cuda_graph_buffers,
+                                               (self.max_num_tokens, ),
                                                cache_name="slot_mapping_fp8",
                                                dtype=torch.int64,
                                                capture_graph=capture_graph)
@@ -368,7 +372,7 @@ def __post_init__(self):
             pin_memory=True,
         )
         self.slot_mapping_scale = self.get_empty(
-            (self.max_num_tokens, ),
+            self.cuda_graph_buffers, (self.max_num_tokens, ),
             cache_name="slot_mapping_scale",
             dtype=torch.int64,
             capture_graph=capture_graph)
@@ -378,26 +382,30 @@ def __post_init__(self):
             pin_memory=True,
         )
         # Per-token request index buffer for topk_indices conversion
-        self.req_idx_per_token = self.get_empty((self.max_num_tokens, ),
+        self.req_idx_per_token = self.get_empty(self.cuda_graph_buffers,
+                                                (self.max_num_tokens, ),
                                                 cache_name="req_idx_per_token",
                                                 dtype=torch.int32,
                                                 capture_graph=capture_graph)
         # Block table for topk_indices conversion (shared for context and generation)
         self.block_table = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
             cache_name="block_table",
             dtype=torch.int32,
             capture_graph=capture_graph)
         self.scheduler_metadata_buffer = self.get_empty(
-            (self.num_sms + 1, 2),
+            self.cuda_graph_buffers, (self.num_sms + 1, 2),
             cache_name="scheduler_metadata_buffer",
             dtype=torch.int32,
             capture_graph=capture_graph)
-        self.cu_seqlen_ks = self.get_empty((self.max_num_tokens, ),
+        self.cu_seqlen_ks = self.get_empty(self.cuda_graph_buffers,
+                                           (self.max_num_tokens, ),
                                            cache_name="cu_seqlen_ks",
                                            dtype=torch.int32,
                                            capture_graph=capture_graph)
-        self.cu_seqlen_ke = self.get_empty((self.max_num_tokens, ),
+        self.cu_seqlen_ke = self.get_empty(self.cuda_graph_buffers,
+                                           (self.max_num_tokens, ),
                                            cache_name="cu_seqlen_ke",
                                            dtype=torch.int32,
                                            capture_graph=capture_graph)
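
The repeated first argument suggests get_empty allocates out of a shared, name-keyed buffer cache: CUDA graphs replay fixed device addresses, so tensors used during capture must stay alive and be reused across replays rather than reallocated. Below is a minimal sketch of such a helper, assuming get_empty caches tensors by cache_name in the cuda_graph_buffers dict; any parameter names and policy beyond what is visible in the diff are assumptions, not the actual TensorRT-LLM implementation.

    # Sketch only: caching policy and signature details are assumed.
    from typing import Dict, Optional, Sequence

    import torch


    def get_empty(cuda_graph_buffers: Optional[Dict[str, torch.Tensor]],
                  shape: Sequence[int],
                  cache_name: str,
                  dtype: torch.dtype,
                  capture_graph: bool = False,
                  device: str = 'cuda') -> torch.Tensor:
        if capture_graph and cuda_graph_buffers is not None:
            buf = cuda_graph_buffers.get(cache_name)
            # Reuse the cached tensor only if it is still compatible, so
            # replays of the captured graph see a stable device address.
            if buf is None or buf.shape != torch.Size(shape) or buf.dtype != dtype:
                buf = torch.empty(tuple(shape), dtype=dtype, device=device)
                cuda_graph_buffers[cache_name] = buf
            return buf
        # Outside graph capture a fresh allocation is fine.
        return torch.empty(tuple(shape), dtype=dtype, device=device)

    # Usage mirroring a call site from the diff (device='cpu' so the
    # example runs without a GPU):
    buffers: Dict[str, torch.Tensor] = {}
    cu_seqlen_ks = get_empty(buffers, (4096, ), cache_name="cu_seqlen_ks",
                             dtype=torch.int32, capture_graph=True,
                             device='cpu')

Passing the cache explicitly rather than through module state is what makes the mechanical threading in this diff necessary: every call site has to say which buffer pool it allocates from, so a rebase that resurrects old call sites drops that argument and breaks initialization.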
