Skip to content

Commit b41a20b

Browse files
committed
Add support for KVCache reuse for DSv32
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
1 parent cc5a058 commit b41a20b

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -930,22 +930,24 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
930930
start_idx=0,
931931
)
932932

933-
if len(chunk_groups) > 1:
933+
if len(chunk_groups
934+
) > 1 or metadata.enable_context_mla_with_cached_kv:
934935
metadata.indexer_prefill_chunks = [
935936
Indexer.prepare_one_prefill_chunk(
936937
metadata,
937938
chunk_specs,
938939
) for chunk_specs in chunk_groups
939940
]
940941
else:
941-
# Single chunk - use non-chunked fallback path
942942
metadata.indexer_prefill_chunks = None
943943

944-
host_cu_seqlen_ks, _ = compute_cu_seqlen_kv_bounds_with_cache(
944+
host_cu_seqlen_ks, host_cu_seqlen_ke = compute_cu_seqlen_kv_bounds_with_cache(
945945
host_seq_lens, num_contexts, num_ctx_tokens, host_cached_tokens)
946946

947947
metadata.cu_seqlen_ks[:num_ctx_tokens].copy_(host_cu_seqlen_ks,
948948
non_blocking=True)
949+
metadata.cu_seqlen_ke[:num_ctx_tokens].copy_(host_cu_seqlen_ke,
950+
non_blocking=True)
949951

950952
# Prepare for decode phase if there are generation requests
951953
if num_generations > 0:
@@ -1016,9 +1018,9 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
10161018
metadata.slot_mapping_scale[:total_tokens].copy_(
10171019
metadata.host_slot_mapping_scale[:total_tokens], non_blocking=True)
10181020

1019-
# Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.
1021+
# When chunked prefill or KVCache reuse is enabled, we need to gather the full KV for indexer's logit computation.
10201022
# Indexer's own chunking does not need full KV gathering, instead it gathers only the current chunk with loop-based gathering.
1021-
_need_full_kv_gathering = num_contexts > 0 and has_mla_chunked_prefill
1023+
_need_full_kv_gathering = num_contexts > 0 and metadata.enable_context_mla_with_cached_kv
10221024
if _need_full_kv_gathering:
10231025
total_kv_len = metadata.host_ctx_kv_indptr[num_contexts].item()
10241026
total_kv_per_request = seq_lens[:

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2417,17 +2417,13 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
24172417
if get_sm_version() == 100 or get_sm_version() == 103:
24182418
moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
24192419
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2420-
# TODO: Support block reuse for DeepSeek-V3.2
2421-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2422-
free_gpu_memory_fraction=0.6,
2420+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
24232421
tokens_per_block=64)
24242422
else:
24252423
if moe_backend != "_DEFAULT":
24262424
pytest.skip("Not supported MoE backend!")
24272425
moe_config = MoeConfig()
2428-
# TODO: Support block reuse for DeepSeek-V3.2
2429-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2430-
free_gpu_memory_fraction=0.7,
2426+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
24312427
tokens_per_block=64)
24322428

24332429
pytorch_config = dict(
@@ -2490,8 +2486,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
24902486
"MOE TRTLLM backend does not support SM version 120 or 121")
24912487

24922488
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2493-
kv_cache_config = KvCacheConfig(enable_block_reuse=True,
2494-
free_gpu_memory_fraction=0.7,
2489+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
24952490
tokens_per_block=64)
24962491
cuda_graph_config = CudaGraphConfig(
24972492
enable_padding=True,
@@ -2550,8 +2545,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size,
25502545
"MOE TRTLLM backend does not support SM version 120 or 121")
25512546

25522547
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
2553-
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
2554-
free_gpu_memory_fraction=0.7,
2548+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
25552549
tokens_per_block=64)
25562550
cuda_graph_config = CudaGraphConfig(
25572551
enable_padding=True,

0 commit comments

Comments (0)