Skip to content

Commit 519b120

Browse files
committed
Replace torch topk with custom topk in context phase
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
1 parent b9a4dbb commit 519b120

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

tensorrt_llm/_torch/attention_backend/sparse/kernel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,8 @@ def rocket_batch_to_flatten_kernel(
672672
token_mask = token_offsets < prefix_budget
673673

674674
# Load from prefix_indices
675-
prefix_indices = valid_idx_in_selected * num_kv_heads * prefix_budget + head_idx * prefix_budget + token_offsets
675+
flattened_idx = valid_idx_in_selected * num_kv_heads + head_idx
676+
prefix_indices = flattened_idx * prefix_budget + token_offsets
676677
prefix_values = tl.load(prefix_indices_ptr + prefix_indices,
677678
mask=token_mask,
678679
other=0)
@@ -717,26 +718,29 @@ def triton_rocket_batch_to_flatten(
717718
prefix_indices: torch.Tensor, input_lens: torch.Tensor,
718719
valid_seq_indices: torch.Tensor, output_offsets: torch.Tensor,
719720
batch_size: int, total_output_tokens: int, window_size: int,
720-
prompt_budget: int) -> tuple[torch.Tensor, torch.Tensor]:
721+
prompt_budget: int,
722+
num_kv_heads: int) -> tuple[torch.Tensor, torch.Tensor]:
721723
"""
722724
Flatten indices considering both valid and invalid batches.
723725
For valid sequences, combines prefix_indices with dynamically computed window indices.
724726
For invalid sequences, generates sequential indices.
725727
726728
Args:
727-
prefix_indices: Selected prefix indices [valid_batch_size, num_kv_heads, prefix_budget]
729+
prefix_indices: Selected prefix indices [valid_batch_size * num_kv_heads, prefix_budget]
728730
input_lens: Lengths for all sequences [batch_size]
729731
valid_seq_indices: Valid sequence indices [valid_batch_size]
730732
output_offsets: Offset for each batch [batch_size + 1]
731733
batch_size: Number of batches
732734
total_output_tokens: Total number of output tokens
733735
window_size: Size of sliding window at the end
734736
prompt_budget: Total number of tokens for valid sequences (prefix_budget + window_size)
737+
num_kv_heads: Number of KV heads
735738
736739
Returns:
737740
sparse_indices: Flattened sparse indices [num_kv_heads, total_output_tokens]
738741
"""
739-
valid_batch_size, num_kv_heads, prefix_budget = prefix_indices.shape
742+
total_tasks, prefix_budget = prefix_indices.shape
743+
valid_batch_size = total_tasks // num_kv_heads
740744

741745
# Create output tensor
742746
sparse_indices = torch.empty((num_kv_heads, total_output_tokens),

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,24 @@ def __post_init__(self):
7575
dtype=torch.int32)
7676

7777
# Context length of RocketKV key for each valid sequence
78-
self.k_context_lens = torch.empty(
79-
self.max_num_sequences,
80-
device='cpu',
78+
self.k_context_lens_cuda = self.get_empty(
79+
self.cuda_graph_buffers,
80+
(self.max_num_sequences, ),
8181
dtype=torch.int32,
82+
cache_name="k_context_lens_cuda",
83+
capture_graph=capture_graph,
84+
)
85+
self.k_context_lens = torch.zeros_like(self.k_context_lens_cuda,
86+
device='cpu',
87+
dtype=torch.int32)
88+
89+
# Start index of RocketKV key for each valid sequence
90+
self.k_context_start_cuda = self.get_empty(
91+
None,
92+
(self.max_num_sequences, ),
93+
dtype=torch.int32,
94+
cache_name="k_context_start_cuda",
95+
capture_graph=capture_graph,
8296
)
8397

8498
# Cumulative context lengths for each sequence
@@ -231,6 +245,8 @@ def prepare(self):
231245
# Only consider sequences that are long enough for sparse kv indices prediction in context phase
232246
self.k_context_lens[:valid_batch_size] = self.prompt_lens_cpu[
233247
valid_seq_indices] - self.window_size
248+
self.k_context_lens_cuda[:valid_batch_size].copy_(
249+
self.k_context_lens[:valid_batch_size], non_blocking=True)
234250

235251
sparse_counts_ctx = torch.zeros(self.num_contexts,
236252
dtype=torch.int32,
@@ -399,12 +415,32 @@ def sparse_kv_predict(
399415
padding=self.kernel_size // 2,
400416
stride=1)
401417

402-
selected_prefix_indices = scores.topk(
403-
self.prompt_budget - self.window_size,
404-
dim=-1).indices.sort().values.to(torch.int32)
418+
# Use indexer topk prefill to select topk prefix indices
419+
total_tasks = metadata.valid_batch_size * self.num_kv_heads
420+
421+
selected_prefix_indices = torch.empty(
422+
(total_tasks, self.prompt_budget - self.window_size),
423+
device=qkv_input.device,
424+
dtype=torch.int32)
425+
426+
scores = scores.view(total_tasks, -1)
427+
428+
row_starts = metadata.k_context_start_cuda[:metadata.
429+
valid_batch_size].repeat_interleave(
430+
self.num_kv_heads)
431+
row_ends = metadata.k_context_lens_cuda[:metadata.
432+
valid_batch_size].repeat_interleave(
433+
self.num_kv_heads)
434+
torch.ops.trtllm.indexer_topk_prefill(
435+
scores, row_starts, row_ends, selected_prefix_indices,
436+
self.prompt_budget - self.window_size)
437+
438+
# Sort selected prefix indices to keep topk indices in ascending order
439+
selected_prefix_indices = torch.sort(selected_prefix_indices,
440+
dim=-1).values
405441
else:
406442
selected_prefix_indices = torch.empty(
407-
(0, self.num_kv_heads, self.prompt_budget - self.window_size),
443+
(0, self.prompt_budget - self.window_size),
408444
device=qkv_input.device,
409445
dtype=torch.int32)
410446

@@ -416,7 +452,7 @@ def sparse_kv_predict(
416452
selected_prefix_indices, metadata.prompt_lens_cuda,
417453
metadata.valid_seq_indices_cuda, sparse_kv_offsets,
418454
metadata.num_contexts, metadata.total_sparse_ctx_indices,
419-
self.window_size, self.prompt_budget)
455+
self.window_size, self.prompt_budget, self.num_kv_heads)
420456

421457
# Update KT cache
422458
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(

0 commit comments

Comments (0)