
Commit 4ab3889

Update GitHub pages in root to v1.2.0rc0
1 parent 4b90534 commit 4ab3889

359 files changed: 87053 additions, 22531 deletions


.buildinfo

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 2820ecdc6d7d98c139c21bcb2df54fee
+config: 05441684cb2c0903bdac9ebb5abe267d
 tags: 645f666f9bcd5a90fca523b33c5a78b7
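As context for the changed `config:` line, Sphinx compares this digest on the next build to decide whether the configuration changed and a full rebuild is needed. The sketch below illustrates the general idea only; it is not Sphinx's actual implementation, and the option names are made up.

```python
import hashlib


def config_fingerprint(config: dict) -> str:
    # Serialize the options in a stable (sorted) order so the digest only
    # changes when an option actually changes.
    serialized = repr(sorted(config.items())).encode("utf-8")
    return hashlib.md5(serialized).hexdigest()


def needs_full_rebuild(stored_digest, config: dict) -> bool:
    # "When it is not found, a full rebuild will be done."
    return stored_digest is None or stored_digest != config_fingerprint(config)


if __name__ == "__main__":
    cfg = {"html_theme": "alabaster", "language": "en"}  # hypothetical options
    print(needs_full_rebuild(None, cfg))                     # True: no stored digest
    print(needs_full_rebuild(config_fingerprint(cfg), cfg))  # False: unchanged config
```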

_cpp_gen/executor.html

Lines changed: 8725 additions & 8669 deletions
Large diffs are not rendered by default.

_cpp_gen/runtime.html

Lines changed: 8226 additions & 8224 deletions
Large diffs are not rendered by default.

llm_kv_cache_offloading.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
### :title KV Cache Offloading
### :order 6
### :section Customization
'''
This script demonstrates the effectiveness of KV cache host offloading in TensorRT-LLM.

**Scenario:**
The script simulates a scenario where the GPU's KV cache is severely limited,
while multiple requests with recurring prompts (like system prompts) are processed.

1. **Constrained GPU Cache:** The GPU KV cache is configured to be very small,
   only large enough to hold the state for a single request.
2. **Alternating Prompts:** Four requests are sent sequentially (batch size of 1)
   with two distinct prompts in an A, B, A, B pattern.
3. **Cache Eviction:** Due to the small GPU cache, processing prompt B will
   force the eviction of the cache generated for prompt A.

**Demonstration:**

* **Without Offloading (Default):**
    - When the first prompt 'A' is processed, its KV cache is stored on the GPU.
    - When prompt 'B' arrives, the cache manager needs space and discards the cache for 'A'.
    - When prompt 'A' is sent again, its cache must be recomputed from scratch.
    - **Expected Outcome:** The log will show `reused blocks: 0` and `cache hit rate: 0`.

* **With Offloading (`--enable_offloading`):**
    - When prompt 'B' arrives, the cache for 'A' is not discarded but is instead
      *offloaded* from the fast GPU VRAM to the slower (but larger) host CPU RAM.
    - When prompt 'A' is sent again, its KV cache is loaded back from host RAM
      to the GPU, which is significantly faster than recomputing it.
    - **Expected Outcome:** The log will show positive values for `reused blocks`
      and a non-zero `cache hit rate`, confirming that the cache was successfully
      reused from the host.

**How to Run & Verify:**

1. **Without Offloading:**
   ```bash
   TLLM_LOG_LEVEL=DEBUG python llm_kv_cache_offloading.py 2>&1 | tee offloading_disabled.log
   ```
   (Check the log for zero reuse)

2. **With Offloading:**
   ```bash
   TLLM_LOG_LEVEL=DEBUG python llm_kv_cache_offloading.py --enable_offloading 2>&1 | tee offloading_enabled.log
   ```
   (Check the log for non-zero reuse)
'''

import argparse

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig


def main(args):
    # Define two distinct prompts to simulate different requests or system prompts.
    prompt_a = (
        "Returns the per-iterations statistics computed since last call to this method. "
        "Contains at most iter_stats_max_iterations iterations.")
    prompt_b = ("Use for skipping decoding step for non generation model, "
                "and return the batch_output (such as mm_embeddings)")

    # Use a batch size of 1 to process requests sequentially, making the cache
    # eviction and reuse cycle easy to observe.
    max_batch_size = 1
    max_seq_len = 256

    # --- KV Cache Configuration ---
    # Set a small GPU KV cache size (in number of tokens). This is crucial for the demo,
    # as it's only large enough to hold the KV cache for a single request.
    kv_cache_max_tokens = 256
    # Define the size of a single cache block.
    kv_cache_page_size = 16
    # Enable a 1 GB host cache if offloading is requested, otherwise disable it (size 0).
    # This is the key toggle for the experiment.
    kv_cache_host_size = 1024**3 if args.enable_offloading else 0

    sampling_params = SamplingParams(max_tokens=max_seq_len)

    llm = LLM(
        model="Qwen/Qwen3-8B",
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        kv_cache_config=KvCacheConfig(
            enable_block_reuse=True,  # Enable reuse of cached blocks
            max_tokens=kv_cache_max_tokens,  # Max tokens in GPU cache
            tokens_per_block=kv_cache_page_size,
            host_cache_size=kv_cache_host_size  # Host cache size for offloading
        ))

    # Process four requests sequentially using two distinct prompts (A, B, A, B).
    # This pattern is designed to showcase the cache eviction and reuse behavior.
    print("--- First Round ---")
    # 1. Process prompt A. Its cache is stored on the GPU.
    output_a = llm.generate(prompt_a, sampling_params)
    print(
        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
    )
    # 2. Process prompt B. Its cache replaces/offloads A's cache.
    output_b = llm.generate(prompt_b, sampling_params)
    print(
        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
    )

    print("\n--- Second Round ---")
    # 3. Process prompt A again.
    #    - Without offloading: Must recompute from scratch.
    #    - With offloading: Recovers cache from host RAM.
    output_a = llm.generate(prompt_a, sampling_params)
    print(
        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
    )
    # 4. Process prompt B again.
    #    - Without offloading: Must recompute from scratch.
    #    - With offloading: Recovers cache from host RAM.
    output_b = llm.generate(prompt_b, sampling_params)
    print(
        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
    )

    llm.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=
        "A script to demonstrate the effectiveness of KV cache host offloading."
    )
    parser.add_argument('--enable_offloading',
                        action='store_true',
                        help='Enable host RAM for KV cache offloading.')
    args = parser.parse_args()
    main(args)
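For the verification step described in the docstring, a small helper can scan the captured logs for the reuse statistics. This is a hypothetical convenience script, not part of the commit; the phrases `reused blocks` and `cache hit rate` are taken from the expected-outcome description above, and the exact wording of the DEBUG log lines may differ between TensorRT-LLM versions.

```python
# summarize_kv_reuse.py -- hypothetical helper, not part of this commit.
import re
import sys


def summarize_reuse(log_path: str) -> None:
    # Look for the statistics named in the example's docstring.
    pattern = re.compile(r"(reused blocks|cache hit rate)\D*([0-9]*\.?[0-9]+)",
                         re.IGNORECASE)
    found = False
    with open(log_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            match = pattern.search(line)
            if match:
                found = True
                print(f"{log_path}: {match.group(1).lower()} = {match.group(2)}")
    if not found:
        print(f"{log_path}: no reuse statistics found in this log")


if __name__ == "__main__":
    # e.g. python summarize_kv_reuse.py offloading_disabled.log offloading_enabled.log
    for path in sys.argv[1:]:
        summarize_reuse(path)
```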

_downloads/b509390ba70e52fabb10dbd9d15d5118/attention.py

Lines changed: 49 additions & 43 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 
-from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm._utils import get_sm_version, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
@@ -24,7 +24,7 @@
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
-from .rotary_embedding import RotaryEmbedding
+from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding
 
 
 def extract_extra_attrs(layer_idx: str, attn_type: str):
@@ -67,6 +67,16 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
+@torch.compile
+def compiled_copy_(dst, src):
+    dst.copy_(src)
+
+
+@torch.compile
+def compiled_cat(tensors, dim):
+    return torch.cat(tensors, dim)
+
+
 @torch.library.custom_op("trtllm::attn_custom_op_inplace",
                          mutates_args=("output", ))
 def attn_custom_op_inplace(
@@ -271,11 +281,19 @@ def __init__(
 
         self.rotary_emb = None
         if not self.rope_fusion and self.pos_embd_params is not None:
-            self.rotary_emb = RotaryEmbedding(
-                self.pos_embd_params.rope,
-                head_dim=self.head_dim,
-                is_neox=self.pos_embd_params.is_neox,
-            )
+            if self.pos_embd_params.type.is_mrope():
+                self.rotary_emb = MRotaryEmbedding(
+                    self.pos_embd_params.rope,
+                    head_dim=self.head_dim,
+                    is_neox=self.pos_embd_params.is_neox,
+                    mrope_section=self.pos_embd_params.mrope_section,
+                )
+            else:
+                self.rotary_emb = RotaryEmbedding(
+                    self.pos_embd_params.rope,
+                    head_dim=self.head_dim,
+                    is_neox=self.pos_embd_params.is_neox,
+                )
 
         self.attn = create_attention(
             self.attn_backend,
@@ -301,6 +319,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -320,12 +344,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype
 
         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -356,11 +376,7 @@ def _attn_impl(
 
         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
             out_scale_sf = self.o_proj.input_scale
@@ -585,7 +601,7 @@ def fp8_block_scaling_bmm_out(
                                                    output)
         out.copy_(output)
 
-    elif sm_version == 100:
+    elif is_sm_100f(sm_version):
         torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
@@ -858,6 +874,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)
 
+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -900,7 +919,7 @@ def create_weights(self):
             ),
             requires_grad=False,
         )
-        if get_sm_version() == 100:
+        if is_sm_100f():
             assert self.dtype == torch.bfloat16
             self.k_b_proj_trans_dequant = nn.Parameter(
                 torch.empty(
@@ -1054,24 +1073,21 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads,
-                                                     self.qk_nope_head_dim)
+        compiled_copy_(k[..., :self.qk_nope_head_dim],
+                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1116,7 +1132,7 @@ def forward_context_with_cached_kv(
             full_k_nope = full_k_nope.view(-1, self.num_heads,
                                            self.qk_nope_head_dim)
             full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            full_k = torch.cat(
+            full_k = compiled_cat(
                 (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
             full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1126,9 +1142,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1138,7 +1151,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1214,7 +1227,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = torch.cat(
+            chunked_k = compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1232,7 +1245,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
 
-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1243,7 +1255,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.
@@ -1273,7 +1285,7 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
@@ -1284,9 +1296,6 @@ def forward_context_with_chunked_prefill(
                                                       num_contexts].sum().item(
                                                       )
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         temp_attn_output = self.mha.forward(
@@ -1296,7 +1305,7 @@ def forward_context_with_chunked_prefill(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             softmax_stats_tensor=self.temp_softmax_stats_tensor,
             chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
             chunked_prefill_buffer_batch_size,
@@ -1394,16 +1403,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
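A note on the recurring changes in this diff: `has_quant_scale` and `out_scale` are now computed once in `create_weights()` and reused across forward calls, and the small `compiled_copy_` / `compiled_cat` helpers wrap per-token tensor ops in `torch.compile`. The usual motivation for the latter pattern is letting the compiler fuse the surrounding `view`/`expand` work instead of materializing intermediates. Below is a standalone sketch of the same pattern with made-up shapes; it assumes a recent PyTorch with `torch.compile` and is illustrative only, not the TensorRT-LLM code path.

```python
import torch


@torch.compile
def compiled_cat(tensors, dim):
    # Compiled concatenation, mirroring the shape of the helper in the diff:
    # the expand below does not need to be materialized as a separate tensor.
    return torch.cat(tensors, dim)


def build_full_k(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    # k_nope: [tokens, num_heads, nope_dim], k_pe: [tokens, 1, rope_dim]
    num_heads = k_nope.shape[1]
    full_k = compiled_cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)
    # Flatten heads back into a [tokens, num_heads * head_dim] layout.
    return full_k.reshape(full_k.shape[0], -1)


if __name__ == "__main__":
    tokens, num_heads, nope_dim, rope_dim = 4, 8, 128, 64
    k_nope = torch.randn(tokens, num_heads, nope_dim)
    k_pe = torch.randn(tokens, 1, rope_dim)
    print(build_full_k(k_nope, k_pe).shape)  # torch.Size([4, 1536])
```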
