diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 9b8efb98e84..4653a4e5a10 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1049,12 +1049,11 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         # Indexer should just process the current MLA chunk as a single chunk
         has_mla_chunked_prefill = (
             metadata.enable_context_mla_with_cached_kv
-            and host_cached_tokens.sum().item() > 0
             and metadata.runtime_features.chunked_prefill)
 
         if has_mla_chunked_prefill:
-            # The MLA has already split the sequence, here just process what's given (as a single chunk)
-            # Cached token info is derived from metadata.host_ctx_cached_token_indptr in prepare_one_prefill_chunk
+            # MLA chunked prefill is active - use single-chunk pattern for
+            # indexer prefill chunks.
             chunk_specs = [(i, 0, host_seq_lens[i].item(),
                             host_seq_lens[:i].sum().item() if i > 0 else 0)
                            for i in range(num_contexts)]
@@ -1065,7 +1064,8 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
                 )
             ]
         else:
-            # Normal mode: use indexer's own chunking logic to prevent L^2 complexity when long-sequence is used.
+            # Use indexer's own chunking logic to prevent L^2 complexity of indexer MQA logits computation for long sequences.
+            # This is only used when MLA chunked prefill is not enabled.
            chunk_groups = split_prefill_chunks(
                 host_seq_lens,
                 metadata.indexer_max_chunk_size,
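
To make the two chunking modes in this diff easier to compare, below is a minimal, self-contained sketch of the behavior they select between. The helper names `single_chunk_specs` and `split_into_capped_chunks` are hypothetical stand-ins (the real code builds `chunk_specs` inline and calls `split_prefill_chunks`), and the `(seq_idx, start_within_seq, chunk_len, global_token_offset)` tuple layout is inferred from the diff rather than taken from dsa.py itself.

```python
import torch


def single_chunk_specs(host_seq_lens: torch.Tensor) -> list:
    """One chunk per context sequence, mirroring the MLA-chunked-prefill branch.

    Tuple layout (seq_idx, start_within_seq, chunk_len, global_token_offset)
    is inferred from the diff above, not from dsa.py documentation.
    """
    specs = []
    offset = 0
    for i, seq_len in enumerate(host_seq_lens.tolist()):
        specs.append((i, 0, seq_len, offset))
        offset += seq_len
    return specs


def split_into_capped_chunks(host_seq_lens: torch.Tensor,
                             max_chunk_size: int) -> list:
    """Hypothetical stand-in for split_prefill_chunks.

    Caps every chunk at max_chunk_size tokens so the indexer's MQA-logits
    cost stays bounded instead of growing as L^2 for one long sequence.
    The real function may group and order chunks differently.
    """
    chunks = []
    offset = 0
    for i, seq_len in enumerate(host_seq_lens.tolist()):
        start = 0
        while start < seq_len:
            chunk_len = min(max_chunk_size, seq_len - start)
            chunks.append((i, start, chunk_len, offset + start))
            start += chunk_len
        offset += seq_len
    return chunks


if __name__ == "__main__":
    seq_lens = torch.tensor([5, 12])
    # MLA chunked prefill active: indexer treats each sequence as a single chunk.
    print(single_chunk_specs(seq_lens))  # [(0, 0, 5, 0), (1, 0, 12, 5)]
    # Normal mode: indexer splits long sequences itself (cap of 4 tokens here).
    print(split_into_capped_chunks(seq_lens, max_chunk_size=4))
```

The point of the split is the same as the comments in the diff: when MLA chunked prefill has already carved the sequence into chunks, the indexer just takes what it is given as a single chunk per sequence; otherwise the indexer applies its own cap so logits computation does not scale quadratically with sequence length.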