
Commit a08eb81

[None][feat] Add RocketKV usage doc and e2e accuracy test on LongBenchV2 (NVIDIA#9572)
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
1 parent 097ac32 commit a08eb81

5 files changed: +236 −39 lines changed
Lines changed: 134 additions & 0 deletions (new file)
# RocketKV Sparse Attention

This document describes how to enable RocketKV sparse attention in TensorRT LLM.

RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (in the context phase) with dynamic KV token selection (in the generation phase) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy.

For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051). The official implementation is available for reference: [RocketKV Repo](https://github.com/NVlabs/RocketKV).

## Overview
In Transformer-based LLM inference, the KV cache grows linearly with sequence length and becomes a major bottleneck. RocketKV mitigates this issue through a two-stage process:

1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores.
2. **Generation Phase (Stage 2):** It performs **dynamic Top-K token selection**. It maintains a lightweight, compressed auxiliary cache (KT cache) to predict which tokens of the KV cache are relevant to the current token, and loads only those tokens for the attention computation (see the sketch below).

RocketKV is integrated into TensorRT LLM as a specialized attention backend, accessible via the LLM API. The core sparse KV prediction kernels are implemented as **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs.
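For intuition, the following is a minimal NumPy sketch of the two stages. The shapes, the scoring rule, and the helper names (`stage1_prompt_eviction`, `stage2_topk_selection`) are simplified illustrative assumptions, not the TensorRT LLM implementation, which runs these steps inside fused Triton kernels.

```python
import numpy as np


def stage1_prompt_eviction(attn_scores, prompt_budget, window_size, kernel_size):
    """Stage 1: decide which prompt tokens survive permanent eviction.

    attn_scores: [num_prompt_tokens] importance of each prompt token, e.g. the
    attention mass it receives from the last `window_size` query tokens.
    """
    num_tokens = attn_scores.shape[0]
    if num_tokens <= prompt_budget:
        return np.arange(num_tokens)  # nothing to evict

    # Smooth scores with 1D max pooling so locally important regions survive,
    # not just isolated high-score tokens.
    pad = kernel_size // 2
    padded = np.pad(attn_scores, pad, mode="edge")
    smoothed = np.array(
        [padded[i:i + kernel_size].max() for i in range(num_tokens)])

    # Always keep the most recent `window_size` tokens ...
    recent = np.arange(num_tokens - window_size, num_tokens)
    # ... and fill the rest of the budget with the top-scoring prefix tokens.
    prefix_scores = smoothed[:num_tokens - window_size]
    kept_prefix = np.argsort(prefix_scores)[-(prompt_budget - window_size):]
    return np.sort(np.concatenate([kept_prefix, recent]))


def stage2_topk_selection(query, kt_cache, topk):
    """Stage 2: predict which cached pages matter for the current decode step.

    kt_cache: [num_pages, head_dim] compressed per-page key representation.
    Only the selected pages are loaded for the real attention computation.
    """
    relevance = kt_cache @ query  # [num_pages]
    return np.argsort(relevance)[-topk:]
```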
## Support Matrix

* GPU Compute Capability >= 10.0 (Blackwell or newer)
* FP16 / BF16 / FP8
* Paged KV Cache
* Tensor Parallel
* CUDA Graph

**Note:**
1. RocketKV currently requires `enable_block_reuse=False` in the KV cache configuration, as the sparse eviction logic is incompatible with the standard block reuse mechanism.
2. RocketKV does not currently support `enable_chunked_prefill=True`.
3. RocketKV also does not support *disaggregated serving*, which requires transmitting the KV cache from the prefill engine to the decode engine; RocketKV currently uses a Python KT cache manager that cannot handle this transmission.
## Usage

To enable RocketKV, configure `RocketSparseAttentionConfig` and pass it to the `LLM` class constructor.

### Python API

Integrate RocketKV into your workflows using the `tensorrt_llm.llmapi` interface.
```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import RocketSparseAttentionConfig, KvCacheConfig

# 1. Define the RocketKV sparse attention configuration
rocket_config = RocketSparseAttentionConfig(
    window_size=32,               # Size of the recent window to always keep
    kernel_size=63,               # Pooling kernel size for importance scoring
    prompt_budget=2048,           # Number of tokens to keep from the prompt (Stage 1)
    topk=64,                      # Number of tokens to retrieve during generation (Stage 2)
    topr=128,                     # Number of query channels to keep for scoring
    kt_cache_dtype='float8_e5m2'  # Dtype for the auxiliary Key-Token cache
)

# 2. Initialize the LLM with the config and the 'pytorch' backend
# Note: block reuse must be disabled for RocketKV
kv_config = KvCacheConfig(enable_block_reuse=False)

llm = LLM(
    model="<path_to_model>",
    backend='pytorch',  # RocketKV currently requires the PyTorch backend
    sparse_attention_config=rocket_config,
    kv_cache_config=kv_config,
)

# 3. Generate
prompts = ["To be or not to be, that is the question."]
sampling_params = SamplingParams(max_tokens=128)
outputs = llm.generate(prompts, sampling_params)
```
### Running the Example Script

We provide a reference script, `examples/llm-api/llm_sparse_attention.py`, to demonstrate RocketKV capabilities.

**Example Command:**

```bash
# Adjust --model_path to your local Llama checkpoint
python3 ../llm-api/llm_sparse_attention.py \
    --model_path <path_to_model> \
    --algo ROCKETKV \
    --attention_backend TRTLLM \
    --window_size 32 \
    --kernel_size 63 \
    --prompt_budget 2048 \
    --topk 64 \
    --topr 128 \
    --kt_cache_dtype float8_e5m2 \
    --max_seq_len 10240 \
    --max_num_tokens 10240 \
    --max_new_tokens 128
```
### Usage with `trtllm-bench` and `trtllm-serve`

For both `trtllm-bench` and `trtllm-serve`, sparse attention options must be specified via `--extra_llm_api_options config.yaml`. All sparse attention options can be set in this YAML file, and the argument names and valid values are the same as in the corresponding configuration described in the Configuration Arguments section. For example, a YAML configuration could look like this:

```yaml
backend: pytorch
attn_backend: TRTLLM
sparse_attention_config:
  algorithm: rocket
  kt_cache_dtype: float8_e5m2
  window_size: 32
  prompt_budget: 2048
kv_cache_config:
  enable_block_reuse: false
enable_chunked_prefill: false
```

Run the command with the config file:

```bash
trtllm-eval/trtllm-bench/trtllm-serve --model <model_path> --extra_llm_api_options extra_config.yaml ...
```
For example, users can evaluate a model with `trtllm-eval` on the LongBenchV2 task like this:

```bash
trtllm-eval --model <path_to_model> --extra_llm_api_options extra_config.yaml longbench_v2 --max_output_length 1024 ...
```

## Configuration Arguments

The `RocketSparseAttentionConfig` allows fine-grained control over compression ratios and performance trade-offs (a rough KV budget estimate based on these settings is sketched after the list):

* **`prompt_budget`** (int, default=2048): The number of tokens to retain from the input prompt (context). RocketKV compresses the prompt to this size by evicting less important tokens based on importance scores.
* **`topk`** (int, default=64): The number of KT pages to select dynamically during the generation phase. Note that the selection is performed at the granularity of KT cache pages, while the actual attention kernel retrieves data at the granularity of the KV cache page size.
* **`topr`** (int/float, default=128): The number of query feature dimensions used when computing the relevance score between the Query and the KT cache. This acts as a dimensionality reduction to speed up the selection process. However, it is recommended to set it equal to `head_dim`, which skips the `topr_filter` computation for better performance and accuracy.
* **`window_size`** (int, default=32): The size of the sliding window in RocketKV. In the context phase, RocketKV uses the last `window_size` tokens of the Query and the Key prefix to compute importance scores for eviction. These recent tokens are always retained in the cache, along with the `prompt_budget - window_size` most important prefix tokens.
* **`kernel_size`** (int, default=63): The size of the 1D max-pooling kernel used in the context phase. It smooths attention scores to better identify locally important regions rather than isolated high-score tokens.
* **`kt_cache_dtype`** (str, default='float8_e5m2'): The data type of the auxiliary "Key-Token" (KT) cache used for relevance prediction.
  * `float8_e5m2`: Recommended. Provides memory savings for the auxiliary structure and a speedup for the prediction kernels.
  * `bfloat16`: Standard precision.
* **`page_size`** (int, default=4): The granularity of the sparse token selection (KT page). Currently, only **powers of 2** are supported due to Triton kernel limitations. Accuracy is generally maintained well for `page_size <= 4`.
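As a back-of-the-envelope illustration of how these settings bound per-request KV traffic, here is a small sketch. The prompt length and the KV cache page size (`kv_page_size`) are assumed values for illustration only, and the exact accounting inside the TensorRT LLM kernels may differ:

```python
# Rough estimate of RocketKV's KV footprint with the documented defaults.
prompt_len = 128_000   # assumed original prompt length
prompt_budget = 2048   # Stage 1: tokens kept permanently (includes the recent window)
window_size = 32       # recent tokens always kept (counted inside prompt_budget)
topk = 64              # Stage 2: KT pages selected per decode step
page_size = 4          # KT page granularity
kv_page_size = 32      # assumed KV cache page size used for retrieval

# Stage 1: fraction of the prompt that survives permanent eviction.
kept_fraction = prompt_budget / prompt_len              # 2048 / 128000 = 1.6%

# Stage 2: tokens touched per decode step. Selection happens in KT pages,
# but retrieval rounds up to whole KV cache pages, so this is an upper bound.
selected_tokens = topk * page_size                      # 256 tokens scored as relevant
retrieved_tokens_upper_bound = topk * kv_page_size      # at most 2048 tokens loaded

print(f"Stage 1 keeps {prompt_budget}/{prompt_len} prompt tokens ({kept_fraction:.1%})")
print(f"Stage 2 loads at most {retrieved_tokens_upper_bound} tokens per decode step")
```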

tensorrt_llm/evaluate/longbench_v2.py

Lines changed: 49 additions & 38 deletions
```diff
@@ -51,23 +51,26 @@ class LongBenchV2(Evaluator):
     DIFFICULTIES = ['easy', 'hard']
     LENGTHS = ['short', 'medium', 'long']
 
-    def __init__(self,
-                 dataset_path: str = 'THUDM/LongBench-v2',
-                 prompts_dir: Optional[str] = None,
-                 num_samples: Optional[int] = None,
-                 start_idx: int = 0,
-                 difficulty: Optional[str] = None,
-                 length: str = 'medium',
-                 domain: Optional[str] = None,
-                 cot: bool = False,
-                 no_context: bool = False,
-                 rag: int = 0,
-                 max_len: int = 128000,
-                 output_dir: Optional[str] = None,
-                 random_seed: int = 0,
-                 apply_chat_template: bool = False,
-                 system_prompt: Optional[str] = None,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+    def __init__(
+        self,
+        dataset_path: str = 'THUDM/LongBench-v2',
+        prompts_dir: Optional[str] = None,
+        num_samples: Optional[int] = None,
+        start_idx: int = 0,
+        difficulty: Optional[str] = None,
+        length: str = 'medium',
+        domain: Optional[str] = None,
+        cot: bool = False,
+        no_context: bool = False,
+        rag: int = 0,
+        max_len: int = 128000,
+        output_dir: Optional[str] = None,
+        random_seed: int = 0,
+        apply_chat_template: bool = False,
+        system_prompt: Optional[str] = None,
+        max_output_length: int = 32000,
+        chat_template_kwargs: Optional[dict[str, Any]] = None,
+    ):
         """Initialize LongBench v2 evaluator.
 
         Args:
@@ -86,6 +89,7 @@ def __init__(self,
             random_seed: Random seed for reproducibility
             apply_chat_template: Whether to apply model's chat template
             system_prompt: System prompt to prepend
+            max_output_length: Maximum output length in tokens. Should keep this value as small as possible to avoid unnecessary truncation.
             chat_template_kwargs: Chat template kwargs as JSON string
         """
         super().__init__(random_seed=random_seed,
@@ -103,6 +107,7 @@ def __init__(self,
         self.no_context = no_context
         self.rag = rag
         self.max_len = max_len
+        self.max_output_length = max_output_length
         self.output_dir = output_dir
 
         # Will be set during evaluation
@@ -307,10 +312,11 @@ def _post_process(self, pred: str) -> str:
         return pred
 
     def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
-        """Truncate prompt to max_len tokens using needle-in-haystack strategy.
+        """Truncate prompt using needle-in-haystack strategy.
 
-        If the prompt exceeds max_len, it takes the first half and last half
+        If the prompt exceeds (max_len - max_output_length), it takes the first half and last half
         to preserve both context beginning and end.
+        We need to minus max_output_length from max_len to reserve budget for output tokens.
 
         Args:
             prompt: The prompt string to truncate
@@ -325,8 +331,9 @@ def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
         try:
             input_ids = tokenizer.encode(prompt, add_special_tokens=False)
 
-            if len(input_ids) > self.max_len:
-                half = self.max_len // 2
+            max_input_len = self.max_len - self.max_output_length
+            if len(input_ids) > max_input_len:
+                half = max_input_len // 2
                 truncated_ids = input_ids[:half] + input_ids[-half:]
                 prompt = tokenizer.decode(truncated_ids,
                                           skip_special_tokens=True)
@@ -791,7 +798,8 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
               type=int,
               default=128000,
               help=
-              "Maximum prompt length in tokens for truncation when building prompts.")
+              "Maximum input and output length in tokens for truncation when building prompts."
+              )
 @click.option("--output_dir",
               type=str,
               default=None,
@@ -843,22 +851,25 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
                                      temperature=0.6,
                                      top_p=0.95)
 
-    evaluator = LongBenchV2(dataset_path=dataset_path,
-                            prompts_dir=prompts_dir,
-                            num_samples=num_samples,
-                            start_idx=start_idx,
-                            difficulty=difficulty,
-                            length=length,
-                            domain=domain,
-                            cot=cot,
-                            no_context=no_context,
-                            rag=rag,
-                            max_len=max_len,
-                            output_dir=output_dir,
-                            random_seed=random_seed,
-                            apply_chat_template=apply_chat_template,
-                            system_prompt=system_prompt,
-                            chat_template_kwargs=chat_template_kwargs)
+    evaluator = LongBenchV2(
+        dataset_path=dataset_path,
+        prompts_dir=prompts_dir,
+        num_samples=num_samples,
+        start_idx=start_idx,
+        difficulty=difficulty,
+        length=length,
+        domain=domain,
+        cot=cot,
+        no_context=no_context,
+        rag=rag,
+        max_len=max_len,
+        output_dir=output_dir,
+        random_seed=random_seed,
+        apply_chat_template=apply_chat_template,
+        system_prompt=system_prompt,
+        max_output_length=max_output_length,
+        chat_template_kwargs=chat_template_kwargs,
+    )
 
     evaluator.evaluate(llm, sampling_params)
     llm.shutdown()
```
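For reference, here is a minimal standalone sketch of the needle-in-haystack truncation logic introduced above; the `WhitespaceTokenizer` below is a hypothetical stand-in used only to make the sketch runnable:

```python
# Sketch of the truncation strategy: if the tokenized prompt exceeds
# (max_len - max_output_length), keep the first and last halves so both the
# beginning and the end of the context survive, reserving room for the output.
def truncate_needle_in_haystack(prompt, tokenizer, max_len=128000, max_output_length=1024):
    input_ids = tokenizer.encode(prompt)
    max_input_len = max_len - max_output_length
    if len(input_ids) <= max_input_len:
        return prompt
    half = max_input_len // 2
    return tokenizer.decode(input_ids[:half] + input_ids[-half:])


class WhitespaceTokenizer:
    """Hypothetical stand-in tokenizer, only for exercising the sketch."""

    def encode(self, text):
        return text.split()

    def decode(self, tokens):
        return " ".join(tokens)


long_prompt = "token " * 200_000
truncated = truncate_needle_in_haystack(long_prompt, WhitespaceTokenizer())
print(len(truncated.split()))  # 126976 tokens kept out of 200000
```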

tests/integration/defs/accuracy/references/longbench_v2.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -7,3 +7,6 @@ DeepSeek-R1-0528:
     kv_cache_quant_algo: FP8
     spec_dec_algo: MTP
     accuracy: 52.093
+meta-llama/Llama-3.1-8B-Instruct:
+  - accuracy: 26.48
+    sigma: 25.8
```

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 49 additions & 1 deletion
```diff
@@ -25,7 +25,8 @@
 from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                  EagleDecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
-                                 SamplingParams, TorchCompileConfig)
+                                 RocketSparseAttentionConfig, SamplingParams,
+                                 TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (get_device_count, get_device_memory, llm_models_root,
@@ -4606,3 +4607,50 @@ def test_auto_dtype(self):
                  max_seq_len=4096) as llm:
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+
+
+@skip_pre_blackwell
+class TestLlama3_1_8B_Instruct_LongBenchV2(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/"
+
+    def test_auto_dtype(self):
+        model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/"
+        if not os.path.exists(model_dir):
+            pytest.skip(f"Model directory {model_dir} does not exist")
+
+        # Configure model settings
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        cuda_graph_config = CudaGraphConfig(enable_padding=True,
+                                            max_batch_size=64)
+
+        sparse_attention_config = RocketSparseAttentionConfig(
+            kt_cache_dtype="float8_e5m2", )
+
+        pytorch_config = dict(cuda_graph_config=cuda_graph_config,
+                              kv_cache_config=kv_cache_config,
+                              sparse_attention_config=sparse_attention_config,
+                              enable_chunked_prefill=False)
+
+        MAX_LEN = 128000
+        MAX_NEW_TOKENS = 1024
+
+        with LLM(model_dir,
+                 max_seq_len=MAX_LEN,
+                 max_num_tokens=128000,
+                 max_batch_size=64,
+                 **pytorch_config) as llm:
+            task = LongBenchV2(self.MODEL_NAME)
+
+            sampling_params = SamplingParams(
+                max_tokens=MAX_NEW_TOKENS,
+                temperature=0.8,
+                top_p=0.95,
+            )
+
+            extra_evaluator_kwargs = dict(max_len=MAX_LEN,
+                                          max_output_length=MAX_NEW_TOKENS)
+            task.evaluate(llm,
+                          sampling_params=sampling_params,
+                          extra_evaluator_kwargs=extra_evaluator_kwargs)
```

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -56,6 +56,7 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_LongBenchV2::test_auto_dtype
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
```
