diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index aa32d6317ef..f88a16adf19 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1329,7 +1329,6 @@ def _gather_k_cache_for_chunk(
     def sparse_attn_indexer(
             self,
             metadata: DSAtrtllmAttentionMetadata,
-            hidden_states: torch.Tensor,
             q_fp8: torch.Tensor,
             k_fp8: torch.Tensor,
             k_scale: torch.Tensor,
@@ -1345,12 +1344,11 @@ def sparse_attn_indexer(
         has_prefill = num_contexts > 0
         num_gen_tokens = num_tokens - num_ctx_tokens
 
-        topk_indices_buffer = torch.empty(
-            (hidden_states.shape[0], self.index_topk),
-            dtype=torch.int32,
-            device=hidden_states.device)
+        topk_indices_buffer = torch.empty((num_tokens, self.index_topk),
+                                          dtype=torch.int32,
+                                          device=q_fp8.device)
         if not use_custom_topk:
-            topk_indices_buffer[:hidden_states.shape[0]] = -1
+            topk_indices_buffer[:num_tokens] = -1
 
         if has_prefill and not metadata.skip_indexer_for_ctx_reqs:
             # Use chunked prefill to reduce memory footprint
@@ -1526,11 +1524,10 @@ def _weight_scale(self, weights: torch.Tensor,
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights
 
-    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+    def _qk_projection_and_rope(self, qr: torch.Tensor, k: torch.Tensor,
                                 position_ids: torch.Tensor):
         """Project Q/K and apply RoPE"""
         q = self.wq_b(qr)
-        k = self.k_norm(indexer_k)
         q = q.view(-1, self.n_heads, self.head_dim)
         q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
                                dim=-1)
@@ -1538,7 +1535,9 @@ def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
                                dim=-1)
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
         k_pe = k_pe[:, 0, :]
-        return q_pe, q_nope, k_pe, k_nope
+        q = self._prep_q_or_k(q_pe, q_nope)
+        k = self._prep_q_or_k(k_pe, k_nope)
+        return q, k
 
     def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
         """Concatenate, rotate, and FP8 quantize for Q or K"""
@@ -1563,14 +1562,7 @@ def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
             self.ln_events[1],
             self.aux_stream,
         )
-        q_pe, q_nope, k_pe, k_nope = q_and_k
-        q, k = maybe_execute_in_parallel(
-            lambda: self._prep_q_or_k(q_pe, q_nope),
-            lambda: self._prep_q_or_k(k_pe, k_nope),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
+        q, k = q_and_k
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
@@ -1586,8 +1578,8 @@ def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
         )
 
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
-        return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights)
+        return self.sparse_attn_indexer(metadata, q_fp8, k_fp8, k_scale,
+                                        weights)
 
 
 class DSATrtllmAttention(TrtllmAttention):
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 47e793e4816..2d68926ed4b 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1363,7 +1363,6 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
             self.indexer.head_dim
         ], -1)
 
-        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1372,9 +1371,14 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
             self.aux_stream,
         )
         qr = q
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+        latent_cache, indexer_k = maybe_execute_in_parallel(
+            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
+            lambda: self.indexer.k_norm(indexer_k),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
 
-        # TODO: fuse wq_b + (indexer) wlq here
         q = self.q_b_proj(q)
         # Indexer
         topk_indices = self.indexer(
diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
index 43872b7ced1..a1c23d0474e 100644
--- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
+++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
             f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")
 
     indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
-    topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
-                                                       hidden_states, q_fp8,
+    topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked, q_fp8,
                                                        k_fp8, k_scale, weights)
 
     print(f"✓ Chunked execution completed, shape: {topk_indices_chunked.shape}")
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
 
     indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
-                                                        hidden_states, q_fp8,
-                                                        k_fp8, k_scale, weights)
+                                                        q_fp8, k_fp8, k_scale,
+                                                        weights)
 
     print(
         f"✓ Non-chunked execution completed, shape: {topk_indices_baseline.shape}"
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
 
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(
             metadata_skip,
-            hidden_states,
            q_fp8,
            k_fp8,
            k_scale,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
 
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
 
     try:
         topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
-                                                          hidden_states,
                                                           q_fp8,
                                                           k_fp8,
                                                           k_scale,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
     metadata_fallback.indexer_prefill_chunks = None
 
     topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
 
     try:
         topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                        hidden_states,
                                                         q_fp8,
                                                         k_fp8,
                                                         k_scale,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
 
     # Test custom kernel
     topk_custom = indexer.sparse_attn_indexer(metadata,
-                                              hidden_states,
                                               q_fp8,
                                               k_fp8,
                                               k_scale,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
 
     # Test fallback
     topk_fallback = indexer.sparse_attn_indexer(metadata,
-                                                hidden_states,
                                                 q_fp8,
                                                 k_fp8,
                                                 k_scale,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
     Indexer.prepare(metadata_skip)
     indexer._update_k_cache(k_fp8, k_scale, metadata_skip)
     topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
-                                                    hidden_states,
                                                     q_fp8,
                                                     k_fp8,
                                                     k_scale,
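
Note on the change above: the indexer's k_norm is hoisted out of Indexer._qk_projection_and_rope and overlapped with the latent-cache concat on the auxiliary stream via maybe_execute_in_parallel, the Q/K FP8 prep is folded into _qk_projection_and_rope so the forward path no longer needs a second parallel section, and sparse_attn_indexer drops the now-unused hidden_states argument (the top-k buffer is sized from num_tokens and allocated on q_fp8's device instead). The sketch below illustrates the dual-stream overlap pattern in isolation; run_overlapped is a hypothetical stand-in for the maybe_execute_in_parallel helper, and the shapes and the LayerNorm used in place of indexer.k_norm are illustrative assumptions, not the library's actual code.

import torch


def run_overlapped(fn_main, fn_aux, ev_main, ev_aux, aux_stream=None):
    """Run fn_main on the current stream and fn_aux on aux_stream concurrently.

    Hypothetical stand-in for the maybe_execute_in_parallel pattern used in the
    patch; with no aux stream it simply runs the two callables sequentially.
    """
    if aux_stream is None:
        return fn_main(), fn_aux()

    main_stream = torch.cuda.current_stream()
    # Publish prior main-stream work so the aux stream sees consistent inputs.
    ev_main.record(main_stream)
    with torch.cuda.stream(aux_stream):
        ev_main.wait()        # aux stream waits for preceding main-stream work
        out_aux = fn_aux()    # e.g. the indexer k_norm branch in the patch
        ev_aux.record()       # mark completion of the aux-stream branch
    out_main = fn_main()      # e.g. the latent-cache concat, overlapping fn_aux
    ev_aux.wait(main_stream)  # main stream must not consume out_aux too early
    return out_main, out_aux


if torch.cuda.is_available():
    aux_stream = torch.cuda.Stream()
    ev0, ev1 = torch.cuda.Event(), torch.cuda.Event()
    # Dummy shapes, not the real DSA/MLA dimensions.
    compressed_kv = torch.randn(8, 512, device="cuda")
    k_pe = torch.randn(8, 64, device="cuda")
    indexer_k = torch.randn(8, 128, device="cuda")
    k_norm = torch.nn.LayerNorm(128, device="cuda")  # stand-in for indexer.k_norm
    latent_cache, indexer_k_normed = run_overlapped(
        lambda: torch.concat([compressed_kv, k_pe], dim=-1),
        lambda: k_norm(indexer_k),
        ev0, ev1, aux_stream)
    print(latent_cache.shape, indexer_k_normed.shape)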