Skip to content

Commit 2c8b44c

Browse files
liji-nv (Lance Liao)
authored and committed
[None][refactor] Move _update_k_cache into sparse_attn_indexer
Move _update_k_cache call to the top of sparse_attn_indexer so the k cache is populated right before prefill chunks gather from it. Remove pre_indexer (now redundant); forward() and forward_dsa_proj both call pre_indexer_proj directly. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 165d4b9 commit 2c8b44c

File tree

2 files changed

+5
-23
lines changed

2 files changed

+5
-23
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,9 @@ def sparse_attn_indexer(
13951395
weights: torch.Tensor,
13961396
use_custom_topk: bool = True,
13971397
) -> torch.Tensor:
1398+
# Update the indexer k cache before prefill chunks gather from it.
1399+
self._update_k_cache(k_fp8, k_scale, metadata)
1400+
13981401
num_contexts = metadata.num_contexts
13991402
num_generations = metadata.num_generations
14001403
num_ctx_tokens = metadata.num_ctx_tokens
@@ -1669,24 +1672,6 @@ def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
16691672
qk_pe, qk_nope, self.scale_fmt == "ue8m0")
16701673
return fp8_out, scale
16711674

1672-
@torch.inference_mode()
1673-
def pre_indexer(
1674-
self, qr: torch.Tensor, hidden_states: torch.Tensor,
1675-
metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor
1676-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1677-
"""Token-wise projections, FP8 quantize, weight scaling, and k cache update.
1678-
1679-
Runs the full indexer pre-computation including k cache update.
1680-
Used by the eager path (Indexer.forward) where everything runs
1681-
outside CUDA graph capture.
1682-
1683-
Returns (q_fp8, k_fp8, k_scale, weights).
1684-
"""
1685-
q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj(
1686-
qr, hidden_states, position_ids)
1687-
self._update_k_cache(k_fp8, k_scale, metadata)
1688-
return q_fp8, k_fp8, k_scale, weights
1689-
16901675
def pre_indexer_proj(
16911676
self, qr: torch.Tensor, hidden_states: torch.Tensor,
16921677
position_ids: torch.Tensor
@@ -1733,8 +1718,8 @@ def pre_indexer_proj(
17331718
def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
17341719
metadata: DSAtrtllmAttentionMetadata,
17351720
position_ids: torch.Tensor):
1736-
q_fp8, k_fp8, k_scale, weights = self.pre_indexer(
1737-
qr, hidden_states, metadata, position_ids)
1721+
q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj(
1722+
qr, hidden_states, position_ids)
17381723

17391724
# Return topk indices buffer for sparse attention [num_tokens, index_topk]
17401725
return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,

tensorrt_llm/_torch/modules/attention.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,9 +1786,6 @@ def forward_dsa_attn(
17861786
k_fp8 = k_fp8[:num_tokens, ...]
17871787
k_scale = k_scale[:num_tokens, ...]
17881788
weights = weights[:num_tokens, ...]
1789-
# Update the indexer k cache here (outside CUDA graph) because
1790-
# it accesses batch-specific metadata (slot_mapping_fp8/scale).
1791-
self.mqa.indexer._update_k_cache(k_fp8, k_scale, attn_metadata)
17921789
topk_indices = self.mqa.indexer.sparse_attn_indexer(
17931790
attn_metadata,
17941791
q, # only used for shape/device in buffer allocation

0 commit comments

Comments (0)