12 changes: 4 additions & 8 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -881,7 +881,10 @@ def prepare(self) -> None:
) <= self.kv_cache_manager.max_seq_len, error_message

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,
# which becomes host_past_key_value_lengths and eventually mMaxSeqLenKv.
self.kv_lens_runtime = kv_lens[:self.num_seqs]
self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
self.host_request_types_runtime = self.host_request_types[:self.
@@ -898,13 +901,6 @@ def prepare_flash_mla(self) -> None:
self.block_ids_per_seq[:self.num_generations, :num_blocks].copy_(
block_ids_per_seq[self.num_contexts:], non_blocking=True)

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
self.host_request_types_runtime = self.host_request_types[:self.
num_seqs]

def pre_process_for_chunked_prefill(
self,
chunked_seq_len: torch.Tensor,
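The comment in the hunk above describes the distinction between the internal kv_lens buffer and kv_lens_runtime. As a minimal sketch of that relationship (not part of the PR; the helper name and structure are illustrative assumptions, not the actual prepare() implementation), assuming an EAGLE3-style setup where num_extra_kv_tokens = max_draft_len - 1 slots are reserved in the KV cache:

import torch

# Hypothetical helper; mirrors the intent of prepare(), not its actual code.
def split_kv_lens(cached_tokens_per_seq, new_kv_tokens_per_seq,
                  num_extra_kv_tokens):
    # Actual KV length: previously cached tokens plus tokens added this step.
    kv_lens_actual = (torch.tensor(cached_tokens_per_seq, dtype=torch.int32) +
                      torch.tensor(new_kv_tokens_per_seq, dtype=torch.int32))
    # Internal bookkeeping (self.kv_lens) keeps the reserved draft-token slots.
    kv_lens_internal = kv_lens_actual + num_extra_kv_tokens
    # What is exposed as kv_lens_runtime (-> host_past_key_value_lengths ->
    # mMaxSeqLenKv) must exclude the reserved slots; that is the fix.
    kv_lens_runtime = kv_lens_actual
    return kv_lens_internal, kv_lens_runtime

# Example with the values used by the new unit test below:
internal, runtime = split_kv_lens([49, 99, 74], [1, 1, 1], num_extra_kv_tokens=7)
assert runtime.tolist() == [50, 100, 75]   # actual KV lengths
assert internal.tolist() == [57, 107, 82]  # actual + reserved extra tokens
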
69 changes: 69 additions & 0 deletions tests/unittest/_torch/speculative/test_eagle3.py
@@ -4,12 +4,15 @@
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)

@@ -22,6 +25,72 @@ def enforce_single_worker(monkeypatch):
yield


def test_kv_lens_runtime_with_eagle3_one_model():
"""
Validates that kv_lens_runtime correctly excludes num_extra_kv_tokens when
preparing attention metadata during EAGLE3 one-model speculative decoding.

Background:
- EAGLE3 reserves num_extra_kv_tokens = max_draft_len - 1 in KV cache for draft token management
- kv_lens_runtime becomes host_past_key_value_lengths, which eventually becomes mMaxSeqLenKv in FMHA kernel
- Bug: mMaxSeqLenKv was incorrectly set to actual_kv_length + num_extra_kv_tokens
- Fix: mMaxSeqLenKv should be set to actual_kv_length only (without extra tokens)

This test validates the fix by directly testing the prepare() logic.
"""

# Test parameters
num_seqs = 3
num_extra_kv_tokens = 7 # e.g., max_draft_len = 8, so extra = 7
prompt_lens = [50, 100, 75] # These represent actual KV lengths
seq_lens_q = [1, 1, 1] # 1 token each in generation
num_cached_tokens_per_seq = [
prompt_lens[i] - seq_lens_q[i] for i in range(num_seqs)
]

# Create a mock KV cache manager
mock_kv_cache_manager = MagicMock()
mock_kv_cache_manager.tokens_per_block = 32
mock_kv_cache_manager.num_pools = 1
mock_kv_cache_manager.max_blocks_per_seq = 16
mock_kv_cache_manager.max_batch_size = num_seqs
mock_kv_cache_manager.max_seq_len = 512 # Large enough to hold our test sequences
mock_kv_cache_manager.impl.copy_batch_block_offsets = MagicMock()

attn_metadata = TrtllmAttentionMetadata(
max_num_requests=num_seqs,
max_num_tokens=sum(seq_lens_q),
kv_cache_manager=mock_kv_cache_manager,
)

# Set required attributes
attn_metadata.request_ids = list(range(1, num_seqs + 1))
attn_metadata.prompt_lens = prompt_lens
attn_metadata._seq_lens = torch.tensor(seq_lens_q, dtype=torch.int32)
# seq_lens_kv is the number of new KV tokens being added in this step (for generation, same as seq_lens_q)
attn_metadata._seq_lens_kv = torch.tensor(seq_lens_q, dtype=torch.int32)

# Set KV cache params with num_extra_kv_tokens (EAGLE3 one-model case)
attn_metadata.kv_cache_params = KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
num_extra_kv_tokens=num_extra_kv_tokens)

attn_metadata.prepare()
actual_kv_lengths = torch.tensor(prompt_lens, dtype=torch.int32)

# kv_lens_runtime should equal actual KV lengths (without extra tokens)
kv_lens_runtime = attn_metadata.kv_lens_runtime[:num_seqs]
assert torch.equal(kv_lens_runtime, actual_kv_lengths), \
f"kv_lens_runtime should be {actual_kv_lengths.tolist()}, but got {kv_lens_runtime.tolist()}"

# Internal kv_lens should include extra tokens
kv_lens_internal = attn_metadata.kv_lens[:num_seqs]
expected_kv_lens_with_extra = actual_kv_lengths + num_extra_kv_tokens
assert torch.equal(kv_lens_internal, expected_kv_lens_with_extra), \
f"kv_lens should be {expected_kv_lens_with_extra.tolist()}, but got {kv_lens_internal.tolist()}"


@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
[
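With the test's parameters, the practical effect of the fix can be summarized with a quick check (assuming, as the docstring suggests, that the kernel's mMaxSeqLenKv is derived from the maximum of host_past_key_value_lengths; the variable names below are illustrative):

# Parameters from test_kv_lens_runtime_with_eagle3_one_model above.
prompt_lens = [50, 100, 75]   # actual KV lengths per sequence
num_extra_kv_tokens = 7       # reserved draft-token slots (max_draft_len - 1)

kv_lens_with_extra = [l + num_extra_kv_tokens for l in prompt_lens]  # [57, 107, 82]

# Before the fix, kv_lens_runtime carried the reserved slots, inflating the
# kernel's maximum KV length; after the fix it reflects only real KV data.
max_seq_len_kv_buggy = max(kv_lens_with_extra)  # 107 (includes 7 unused slots)
max_seq_len_kv_fixed = max(prompt_lens)         # 100 (actual longest KV)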