Commit f1fe48d

Shunkangz authored and committed
[TRTLLM-9159][doc] Add KV Connector docs (NVIDIA#9043)
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>
1 parent a5623e4 commit f1fe48d

File tree

1 file changed: +84 -5 lines

examples/llm-api/llm_kv_cache_connector.py

Lines changed: 84 additions & 5 deletions
@@ -1,6 +1,84 @@
### :title KV Cache Connector
### :order 6
### :section Customization
'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1. **Save KV Cache:** Persist computed KV cache blocks to an external storage
   (disk, database, distributed cache, etc.)
2. **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3. **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
   or sessions, unlike regular block reuse, which is limited to a single instance

**How It Works:**

This example implements a `PersistentKvCacheConnector` with two key components:

* **PersistentKvCacheConnectorLeader (Scheduler):**
  - Hashes token sequences to create unique identifiers for each cache block
  - Checks if cached blocks exist on disk for incoming requests
  - Schedules load operations for cache hits
  - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
  - Executes the actual load/save operations between GPU and disk
  - Loads cached blocks from disk files into GPU memory
  - Saves newly computed blocks from GPU to disk files

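The leader's hashing step can be sketched in a few lines. This is a minimal illustration, not the example's actual API: `block_file_name` is a hypothetical helper, and the exact hashing scheme used by `PersistentKvCacheConnectorLeader` may differ.

```python
import hashlib

def block_file_name(tokens: list[int]) -> str:
    # Hypothetical helper: derive a deterministic cache-file name from a
    # block's token ids, so identical blocks map to the same file on disk.
    digest = hashlib.sha256(str(tokens).encode("utf-8")).hexdigest()
    return f"{digest}.pt"

# A second run with the same tokens produces the same name, i.e. a cache hit.
assert block_file_name([1, 2, 3, 4]) == block_file_name([1, 2, 3, 4])
assert block_file_name([1, 2, 3, 4]) != block_file_name([1, 2, 3, 5])
```

Because the name is derived purely from content, no index or manifest is needed: an `os.path.exists` check on the derived name is the cache lookup.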
**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1. **First Run (Instance 1):**
   - The LLM computes the KV cache for the input prompt
   - The connector saves the computed cache blocks to disk (as .pt files)
   - The generation completes and the LLM instance is destroyed

2. **Second Run (Instance 2):**
   - A new LLM instance is created with the same connector configuration
   - When processing the same prompt, the connector finds matching cache blocks on disk
   - The cache is loaded from disk instead of being recomputed
   - **Expected Outcome:** Faster prefill, as cache blocks are loaded rather than computed
   - Both outputs should be identical, demonstrating deterministic cache reuse

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

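The worker's save/load round trip described above can be sketched with standard-library stand-ins. The real example moves tensors between GPU memory and `.pt` files (via torch); here `save_block`/`load_block`, the pickle serialization, and the plain-list payload are all illustrative assumptions.

```python
import os
import pickle
import tempfile

# Stand-in cache directory; the real example uses a TemporaryDirectory too.
cache_dir = tempfile.mkdtemp()

def save_block(block_id: str, kv_block) -> None:
    # Persist one block under its content-derived identifier.
    with open(os.path.join(cache_dir, block_id), "wb") as f:
        pickle.dump(kv_block, f)

def load_block(block_id: str):
    # Return the cached block, or None on a cache miss (block must be computed).
    path = os.path.join(cache_dir, block_id)
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)

save_block("block0", [[0.1, 0.2], [0.3, 0.4]])
assert load_block("block0") == [[0.1, 0.2], [0.3, 0.4]]  # hit
assert load_block("missing") is None                     # miss
```

Any storage backend with the same put/get-or-miss contract (database, object store, distributed cache) could replace the file operations.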
**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example uses content-based hashing to identify cache blocks
- Cache files are stored in a temporary directory (cleaned up after the demo)
- The implementation is simplified and not optimized for production use
- Does not support chunked prefill in this example
- See `tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py` for the full connector interface

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
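KV cache blocks are only reusable when everything before them matches, which content-based hashing schemes often enforce by chaining each block's hash with its predecessor's. Whether the shipped example chains hashes exactly this way is an assumption here; `chained_block_hashes` is an illustrative sketch, not the connector's API.

```python
import hashlib

def chained_block_hashes(tokens: list[int], block_size: int) -> list[str]:
    # Chain each block's hash with the previous one, so a block's identifier
    # depends on its entire prefix. A trailing partial block is ignored.
    hashes, prev = [], ""
    full = len(tokens) - len(tokens) % block_size
    for start in range(0, full, block_size):
        block = tokens[start:start + block_size]
        h = hashlib.sha256((prev + str(block)).encode("utf-8")).hexdigest()
        hashes.append(h)
        prev = h
    return hashes

a = chained_block_hashes([1, 2, 3, 4], block_size=2)
b = chained_block_hashes([1, 2, 9, 9], block_size=2)
# The shared-prefix block matches; divergence invalidates every later block.
assert a[0] == b[0] and a[1] != b[1]
```

This is why two prompts that share only a prefix can still share the cached blocks covering that prefix.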
'''

import os
import sys
@@ -17,11 +95,6 @@
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

-# This is a simple example of the use of the KV cache connector.
-# It persists KV cache contents into a folder, and can load them back on subsequent runs.
-# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
-# NOTE: This example connector implementation is NOT suitable for production use.
-
CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"

@@ -198,6 +271,7 @@ def main(model: str):

    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

+   # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
@@ -207,6 +281,7 @@ def main(model: str):
    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

+   # Create an LLM instance with the KV cache connector
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
@@ -220,6 +295,7 @@ def main(model: str):

    sampling_params = SamplingParams(max_tokens=32)

+   # Generate text with the first LLM instance; the connector saves the computed KV cache blocks.
    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

@@ -228,16 +304,19 @@ def main(model: str):

    del llm

+   # Create a new LLM instance with the same connector configuration
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

+   # Generate text with the second LLM instance; it should reuse the KV cache blocks saved by the connector.
    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

+   # Verify that the two outputs are identical
    assert text0 == text1

    connector_cache_dir.cleanup()
