8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1049,12 +1049,11 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         # Indexer should just process the current MLA chunk as a single chunk
         has_mla_chunked_prefill = (
             metadata.enable_context_mla_with_cached_kv
-            and host_cached_tokens.sum().item() > 0
             and metadata.runtime_features.chunked_prefill)

         if has_mla_chunked_prefill:
-            # The MLA has already split the sequence, here just process what's given (as a single chunk)
-            # Cached token info is derived from metadata.host_ctx_cached_token_indptr in prepare_one_prefill_chunk
+            # MLA chunked prefill is active - use single-chunk pattern for
+            # indexer prefill chunks.
             chunk_specs = [(i, 0, host_seq_lens[i].item(),
                             host_seq_lens[:i].sum().item() if i > 0 else 0)
                            for i in range(num_contexts)]
@@ -1065,7 +1064,8 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
             )
         ]
     else:
-        # Normal mode: use indexer's own chunking logic to prevent L^2 complexity when long-sequence is used.
+        # Use indexer's own chunking logic to prevent L^2 complexity of indexer MQA logits computation for long sequences.
+        # This is only used when MLA chunked prefill is not enabled.
         chunk_groups = split_prefill_chunks(
             host_seq_lens,
             metadata.indexer_max_chunk_size,
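
For context on the two branches: the single-chunk path builds specs of the form (i, 0, host_seq_lens[i], cumulative_offset), which suggests a layout of (seq_idx, start_within_seq, chunk_len, global_token_offset); that reading of the second field, and the sketch below, are assumptions rather than the actual split_prefill_chunks implementation in dsa.py (the real helper takes additional arguments and returns chunk groups). The sketch only illustrates the capping behavior the new comment describes: bounding each indexer chunk at indexer_max_chunk_size so MQA logits are computed per chunk instead of over the full sequence.

# Hypothetical sketch only -- not the TensorRT-LLM split_prefill_chunks.
# Tuple layout (seq_idx, start_within_seq, chunk_len, global_token_offset)
# is inferred from the single-chunk specs in the diff above.
from typing import List, Tuple

def split_prefill_chunks_sketch(
        seq_lens: List[int],
        max_chunk_size: int) -> List[Tuple[int, int, int, int]]:
    """Split each context sequence into chunks of at most max_chunk_size
    tokens, so indexer MQA logits are computed chunk-by-chunk rather than
    over the whole sequence at once (avoiding O(L^2)-sized logit buffers)."""
    chunk_specs = []
    global_offset = 0  # running token offset across all context sequences
    for seq_idx, seq_len in enumerate(seq_lens):
        start = 0
        while start < seq_len:
            chunk_len = min(max_chunk_size, seq_len - start)
            chunk_specs.append(
                (seq_idx, start, chunk_len, global_offset + start))
            start += chunk_len
        global_offset += seq_len
    return chunk_specs

# Example: two context sequences of 5 and 3 tokens, max chunk size 4 ->
# [(0, 0, 4, 0), (0, 4, 1, 4), (1, 0, 3, 5)]
print(split_prefill_chunks_sketch([5, 3], max_chunk_size=4))

Capping each chunk at the configured max size keeps the per-chunk logits matrix bounded, so total indexer work grows roughly linearly with sequence length; when MLA chunked prefill is enabled, MLA has already done an equivalent split, which is why that branch emits exactly one spec per sequence.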