30 changes: 11 additions & 19 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1329,7 +1329,6 @@ def _gather_k_cache_for_chunk(
     def sparse_attn_indexer(
         self,
         metadata: DSAtrtllmAttentionMetadata,
-        hidden_states: torch.Tensor,
         q_fp8: torch.Tensor,
         k_fp8: torch.Tensor,
         k_scale: torch.Tensor,
@@ -1345,12 +1344,11 @@ def sparse_attn_indexer(
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens

-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1

         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
@@ -1526,19 +1524,20 @@ def _weight_scale(self, weights: torch.Tensor,
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
         k_pe, k_nope = k.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k

     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
@@ -1563,14 +1562,7 @@ def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
             self.ln_events[1],
             self.aux_stream,
         )
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
@@ -1586,8 +1578,8 @@
         )

         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)


 class DSATrtllmAttention(TrtllmAttention):
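The upshot of the dsa.py changes: `sparse_attn_indexer` no longer takes `hidden_states`, since it was only used to size and place the top-k buffer, and both are recoverable from `num_tokens` and `q_fp8`. A minimal sketch of the new allocation, with stand-in values for what the real code reads from `metadata` and `self` (the scalar values below are hypothetical):

```python
import torch

# Stand-in values; in the PR these come from metadata (num_tokens) and the
# indexer config (self.index_topk).
num_tokens, index_topk = 8, 4
q_fp8 = torch.zeros(num_tokens, 16)  # stands in for the FP8 query tensor

# The buffer is sized from the token count and allocated on the query's
# device, so hidden_states no longer needs to be threaded through the chain.
topk_indices_buffer = torch.empty((num_tokens, index_topk),
                                  dtype=torch.int32,
                                  device=q_fp8.device)
topk_indices_buffer[:num_tokens] = -1  # fallback path pre-fills invalid slots
```

The `_qk_projection_and_rope` change is complementary: folding `_prep_q_or_k` into it removes the caller's second `maybe_execute_in_parallel` round-trip in `forward`.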
10 changes: 7 additions & 3 deletions tensorrt_llm/_torch/modules/attention.py
@@ -1363,7 +1363,6 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
                 self.indexer.head_dim
             ], -1)

-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1372,9 +1371,14 @@
             self.aux_stream,
         )
         qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )

         # TODO: fuse wq_b + (indexer) wlq here
         q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
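Moving `indexer.k_norm` under `maybe_execute_in_parallel` lets the layernorm overlap with the `latent_cache` concat on an auxiliary stream instead of running serially after it. A rough, illustrative stand-in for that overlap pattern, assuming event-fenced dual-stream execution (this is not TensorRT-LLM's actual helper):

```python
import torch

def run_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream):
    """Illustrative stand-in for maybe_execute_in_parallel (assumed
    semantics): fn_a runs on the current stream, fn_b on aux_stream."""
    event_a.record()              # point after which aux work may begin
    with torch.cuda.stream(aux_stream):
        event_a.wait()            # aux stream waits for prior main-stream work
        out_b = fn_b()
        event_b.record()          # marks completion of aux-stream work
    out_a = fn_a()                # overlaps with fn_b on the aux stream
    event_b.wait()                # main stream waits before consuming out_b
    return out_a, out_b

# Hypothetical usage mirroring the diff above:
# latent_cache, indexer_k = run_in_parallel(
#     lambda: torch.concat([compressed_kv, k_pe], dim=-1),
#     lambda: indexer.k_norm(indexer_k),
#     torch.cuda.Event(), torch.cuda.Event(), torch.cuda.Stream())
```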
18 changes: 3 additions & 15 deletions tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
hidden_states, q_fp8,
topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked, q_fp8,
k_fp8, k_scale, weights)

print(f"✓ Chunked execution completed, shape: {topk_indices_chunked.shape}")
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):

     indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
-                                                        hidden_states, q_fp8,
-                                                        k_fp8, k_scale, weights)
+                                                        q_fp8, k_fp8, k_scale,
+                                                        weights)

     print(
         f"✓ Non-chunked execution completed, shape: {topk_indices_baseline.shape}"
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(
             metadata_skip,
-            hidden_states,
             q_fp8,
             k_fp8,
             k_scale,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     metadata_fallback.indexer_prefill_chunks = None

     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,

     try:
         topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):

     # Test custom kernel
     topk_custom = indexer.sparse_attn_indexer(metadata,
-                                              hidden_states,
                                               q_fp8,
                                               k_fp8,
                                               k_scale,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):

     # Test fallback
     topk_fallback = indexer.sparse_attn_indexer(metadata,
-                                                hidden_states,
                                                 q_fp8,
                                                 k_fp8,
                                                 k_scale,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     Indexer.prepare(metadata_skip)
     indexer._update_k_cache(k_fp8, k_scale, metadata_skip)
     topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                    hidden_states,
                                                     q_fp8,
                                                     k_fp8,
                                                     k_scale,
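The test updates are mechanical: every call site drops the `hidden_states` argument. A condensed, hypothetical helper capturing the custom-vs-fallback comparison these tests repeat (the helper itself is not part of the PR; names mirror the tests):

```python
import torch

def assert_topk_paths_match(indexer, metadata_custom, metadata_fallback,
                            q_fp8, k_fp8, k_scale, weights):
    # Run the indexer once with custom-topk metadata and once with fallback
    # metadata, then require identical top-k indices. Neither call passes
    # hidden_states after this PR.
    topk_custom = indexer.sparse_attn_indexer(metadata_custom, q_fp8, k_fp8,
                                              k_scale, weights)
    topk_fallback = indexer.sparse_attn_indexer(metadata_fallback, q_fp8,
                                                k_fp8, k_scale, weights)
    torch.testing.assert_close(topk_custom, topk_fallback)
```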