- Overview & Motivation
- Theoretical Background
- Mathematical Formulation
- High-Level Intuition
- Implementation Details
- Code Walkthrough
- Optimization Tricks
- Experiments & Results
- Common Pitfalls
- References
Autoregressive language model generation is inherently sequential: each token depends on all previous tokens. Without optimization, generating each token requires:
- Recomputing attention over all previous tokens
- Memory quadratic in sequence length
- Computational redundancy - we recompute the same Key/Value projections repeatedly
For a sequence of length T, generating the final token naively requires:
- Time: O(T² × d) to recompute attention over the full prefix
- Memory: O(T²) per head for the attention score matrices
- Redundancy: the same Key/Value projections were already computed in the T-1 earlier forward passes
Key insight: In autoregressive generation, the Key and Value projections for past tokens never change. We can cache them and reuse them for subsequent tokens.
Impact:
- Latency: Reduces per-token generation from O(T²) to O(T) compute
- Memory: Trades O(T²) attention matrix for O(L × T × d) cache (L=layers, d=dimension)
- Throughput: Enables batching by reducing memory pressure per sequence
Memory Trade-off:
Without KV Cache: O(T²) attention → Cannot batch
With KV Cache: O(L×T×d) cache → Can batch effectively
Standard multi-head attention:
Q = X @ W_Q # (T, d) @ (d, d_k) = (T, d_k)
K = X @ W_K # (T, d) @ (d, d_k) = (T, d_k)
V = X @ W_V # (T, d) @ (d, d_v) = (T, d_v)
Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V
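For reference, here is a minimal single-head PyTorch sketch of these projections and the causal attention computation; the sizes and random weights are toy values for illustration, not any particular model's configuration.

import math
import torch

# Toy sizes; in the formulas above d_k = d_v, so one projection width suffices here.
T, d, d_k = 8, 32, 16
X = torch.randn(T, d)
W_Q, W_K, W_V = (torch.randn(d, d_k) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V               # each (T, d_k)
scores = Q @ K.T / math.sqrt(d_k)                 # (T, T)
mask = torch.tril(torch.ones(T, T)).bool()        # causal: position i sees only positions <= i
scores = scores.masked_fill(~mask, float('-inf'))
out = torch.softmax(scores, dim=-1) @ V           # (T, d_k)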
At step t, we have tokens [x₁, x₂, ..., xₜ₋₁] and want to generate xₜ:
Without caching:
# Step t: Process entire sequence again
X_t = [x_1, x_2, ..., x_{t-1}, x_t] # shape: (t, d)
Q_t = X_t @ W_Q # Recompute all queries
K_t = X_t @ W_K # Recompute all keys (REDUNDANT!)
V_t = X_t @ W_V   # Recompute all values (REDUNDANT!)

With caching:
# Step t: Only compute new token
x_t_embedding # shape: (1, d)
q_t = x_t @ W_Q # Only new query
k_t = x_t @ W_K # Only new key
v_t = x_t @ W_V # Only new value
# Concatenate with cache
K_cached = [k_1, ..., k_{t-1}] # Retrieved from cache
V_cached = [v_1, ..., v_{t-1}] # Retrieved from cache
K_t = concat(K_cached, k_t) # shape: (t, d_k)
V_t = concat(V_cached, v_t)   # shape: (t, d_v)

Key property: In causal (autoregressive) attention, the Key and Value vectors for position i never depend on future tokens, so once computed they never change as generation proceeds; the short check below illustrates this.
Attention(q_t, K_{1:t}, V_{1:t}) = softmax(q_t @ K_{1:t}^T / √d_k) @ V_{1:t}
↑
Only q_t is "new" at step t
K_{1:t-1} and V_{1:t-1} are unchanged
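A quick numerical check of this property with toy random weights: stacking the per-token Key projections (exactly what a cache stores) reproduces the Key matrix you would get by recomputing over the whole sequence.

import torch

T, d, d_k = 6, 32, 16
X = torch.randn(T, d)          # embeddings for tokens x_1..x_T
W_K = torch.randn(d, d_k)

K_full = X @ W_K                                                    # recompute everything at step T
K_cached = torch.cat([X[t:t + 1] @ W_K for t in range(T)], dim=0)   # one cached row per step

assert torch.allclose(K_full, K_cached, atol=1e-5)   # identical, so caching is lossless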
Without KV Cache (Naive Generation)
For sequence length T, at step t:
- Compute: O(t × d²) for projections
- Attention: O(t² × d) for QK^T and attention @ V
- Total per token: O(t × d² + t² × d)
- Total for T tokens: O(T² × d² + T³ × d)
With KV Cache
For sequence length T, at step t:
- Compute: O(d²) for single token projection
- Attention: O(t × d) for q_t @ K_{1:t}^T
- Cache writes: O(L × d) per token to append k_t and v_t across all layers
- Total per token: O(d² + t × d)
- Total for T tokens: O(T × d² + T² × d)
Speedup Analysis
Asymptotic speedup for long sequences:
Speedup = (T² × d² + T³ × d) / (T × d² + T² × d) = T
For typical LLMs (d=4096, T=2048): ~2000x speedup in theory!
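As a quick arithmetic check (and nothing more than that), plugging the leading-order cost expressions above into Python reproduces the headline number:

# Ratio of the two closed-form totals above, leading terms only.
T, d = 2048, 4096
naive_total = T**2 * d**2 + T**3 * d    # without cache: O(T^2 d^2 + T^3 d)
cached_total = T * d**2 + T**2 * d      # with cache:    O(T d^2 + T^2 d)
print(naive_total / cached_total)       # -> 2048.0, i.e. roughly T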
In practice: 5-20x speedup due to:
- Memory bandwidth bottlenecks
- Cache management overhead
- Other non-attention operations
KV Cache size per layer:
Memory = 2 × (batch_size × num_heads × seq_len × head_dim) × bytes_per_element
For LLaMA-7B (batch=1, seq=2048, FP16):
= 2 × (1 × 32 × 2048 × 128) × 2 bytes
= 32 MB per layer
For 32 layers:
= 32 × 32 MB = 1 GB total
Scaling with batch size:
KV_cache_size = 2 × L × B × H × T × d × bytes
Where:
L = num_layers
B = batch_size
H = num_heads
T = max_seq_len
d = head_dim
bytes = 2 (FP16) or 1 (INT8) or 0.5 (INT4)
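A small helper that evaluates this formula (a convenience sketch, not any library's API); it reproduces the roughly 1 GB figure computed above for LLaMA-7B at batch 1, 2048 tokens, FP16.

def kv_cache_bytes(num_layers: int, batch: int, num_heads: int,
                   seq_len: int, head_dim: int, bytes_per_elem: float = 2.0) -> float:
    """Total KV cache size in bytes; the factor 2 covers keys and values."""
    return 2 * num_layers * batch * num_heads * seq_len * head_dim * bytes_per_elem

size_gb = kv_cache_bytes(num_layers=32, batch=1, num_heads=32,
                         seq_len=2048, head_dim=128) / 1024**3
print(f"{size_gb:.2f} GB")   # -> 1.00 GB for LLaMA-7B in FP16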
Step 1: Generate token 1
┌─────────────────────────────────────┐
│ Input: [BOS] │
│ │
│ Compute: Q₁, K₁, V₁ │
│ │
│ Cache: K₁ → [K₁] │
│ V₁ → [V₁] │
│ │
│ Output: token₁ │
└─────────────────────────────────────┘
Step 2: Generate token 2
┌─────────────────────────────────────┐
│ Input: token₁ │
│ │
│ Compute: Q₂, K₂, V₂ ← Only new! │
│ │
│ Cache: [K₁, K₂] ← Append │
│ [V₁, V₂] │
│ │
│ Attention: Q₂ @ [K₁, K₂]ᵀ │
│ │
│ Output: token₂ │
└─────────────────────────────────────┘
Step t: Generate token t
┌─────────────────────────────────────┐
│ Input: token_{t-1} │
│ │
│ Compute: Q_t, K_t, V_t │
│ │
│ Cache: [K₁...K_{t-1}, K_t] │
│ [V₁...V_{t-1}, V_t] │
│ ↑ │
│ Retrieved from cache │
│ │
│ Attention: Q_t @ K_{1:t}ᵀ │
│ │
│ Output: token_t │
└─────────────────────────────────────┘
┌──────────────┐
│ Allocation │ Pre-allocate fixed-size buffers
│ │ Shape: (batch, heads, max_len, dim)
└──────┬───────┘
│
▼
┌──────────────┐
│ Prefill │ Process prompt tokens
│ │ Fill cache with K/V for prompt
└──────┬───────┘
│
▼
┌──────────────┐
│ Generation │ Autoregressive token generation
│ │ - Compute K_t, V_t for new token
│ │ - Append to cache at position t
│ │ - Attend over K_{1:t}, V_{1:t}
└──────┬───────┘
│
▼
┌──────────────┐
│ Cleanup │ Free or reuse cache buffers
└──────────────┘
Contiguous allocation vs Dynamic allocation:
# Contiguous (better for GPU)
cache_shape = (batch, num_heads, max_seq_len, head_dim)
k_cache = torch.zeros(cache_shape, dtype=torch.float16, device='cuda')
# Pros: Fast indexing, good memory coalescing
# Cons: Wastes memory for variable-length sequences

Challenge: Different sequences in a batch have different lengths.
# Attention mask handles variable lengths
attention_mask[batch_idx, :seq_len[batch_idx]] = True
attention_mask[batch_idx, seq_len[batch_idx]:] = False
# Only valid positions attend
scores = Q @ K.transpose(-2, -1)   # (batch, heads, 1, seq_len)
scores.masked_fill_(~attention_mask[:, None, None, :], float('-inf'))

Each layer has independent K/V caches:
# List of caches (one per layer)
self.k_cache = [
torch.zeros(cache_shape, ...)
for _ in range(num_layers)
]

Track current sequence length per batch item:
self.seq_lens = torch.zeros(batch_size, dtype=torch.long)
# After generating token t
self.seq_lens += 1  # All sequences advance by 1

A static, pre-allocated cache (the contiguous option above):

class StaticKVCache:
    def __init__(self, max_seq_len, ...):
        self.k_cache = torch.zeros(
            (batch, heads, max_seq_len, dim)
        )
        self.v_cache = torch.zeros_like(self.k_cache)
        self.current_len = 0

    def append(self, k, v):
        self.k_cache[:, :, self.current_len, :] = k
        self.v_cache[:, :, self.current_len, :] = v
        self.current_len += 1

Pros: Simple, fast indexing. Cons: Memory waste for short sequences.

A dynamic, append-based alternative:
class DynamicKVCache:
    def __init__(self):
        self.cache = []  # list of (k, v) pairs, one per generated position

    def append(self, k, v):
        self.cache.append((k, v))

    def get_full(self):
        keys = torch.cat([kv[0] for kv in self.cache], dim=2)
        values = torch.cat([kv[1] for kv in self.cache], dim=2)
        return keys, values

Pros: No memory waste. Cons: Slower concatenation, memory fragmentation.
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class KVCache:
    """Simple pre-allocated KV cache for autoregressive generation."""
def __init__(
self,
num_layers: int,
max_batch_size: int,
max_seq_len: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cuda')
):
self.num_layers = num_layers
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors
cache_shape = (max_batch_size, num_heads, max_seq_len, head_dim)
self.k_cache = [
torch.zeros(cache_shape, dtype=dtype, device=device)
for _ in range(num_layers)
]
self.v_cache = [
torch.zeros(cache_shape, dtype=dtype, device=device)
for _ in range(num_layers)
]
# Track sequence lengths
self.seq_lens = torch.zeros(max_batch_size, dtype=torch.long, device=device)

def update(
self,
layer_idx: int,
key: torch.Tensor, # (batch, heads, seq_len, dim)
value: torch.Tensor,
start_pos: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update cache with new K/V and return full cached tensors."""
batch_size, _, seq_len, _ = key.shape
if start_pos is None:
start_pos = self.seq_lens[0].item()
# Write new K/V to cache at current position
end_pos = start_pos + seq_len
self.k_cache[layer_idx][:batch_size, :, start_pos:end_pos, :] = key
self.v_cache[layer_idx][:batch_size, :, start_pos:end_pos, :] = value
# Update sequence lengths
self.seq_lens[:batch_size] = end_pos
# Return full cached K/V (from start to current position)
return (
self.k_cache[layer_idx][:batch_size, :, :end_pos, :],
self.v_cache[layer_idx][:batch_size, :, :end_pos, :]
)

def attention_with_cache(
self,
query: torch.Tensor, # (batch, heads, 1, dim) - new token
key: torch.Tensor, # (batch, heads, 1, dim) - new token
value: torch.Tensor, # (batch, heads, 1, dim) - new token
layer_idx: int,
cache: KVCache
) -> torch.Tensor:
"""Attention using KV cache."""
# Update cache and get full K/V including history
full_k, full_v = cache.update(layer_idx, key, value)
# full_k: (batch, heads, seq_len, dim)
# full_v: (batch, heads, seq_len, dim)
# Compute attention over full history
# Q: (batch, heads, 1, dim)
# K: (batch, heads, seq_len, dim)
scores = query @ full_k.transpose(-2, -1) / math.sqrt(self.head_dim)
# scores: (batch, heads, 1, seq_len)
attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ full_v
# output: (batch, heads, 1, dim)
return output

@torch.no_grad()
def generate_with_cache(
model: nn.Module,
input_ids: torch.Tensor, # (batch, prompt_len)
max_new_tokens: int = 100,
cache: Optional[KVCache] = None,
eos_token_id: Optional[int] = None
) -> torch.Tensor:
"""Generate tokens using KV cache."""
batch_size, prompt_len = input_ids.shape
# Initialize cache if needed
if cache is None:
cache = KVCache(
num_layers=model.config.num_layers,
max_batch_size=batch_size,
max_seq_len=prompt_len + max_new_tokens,
num_heads=model.config.num_heads,
head_dim=model.config.head_dim
)
# Prefill phase: process entire prompt at once
outputs = model(input_ids, use_cache=True, cache=cache)
next_token_logits = outputs.logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated = torch.cat([input_ids, next_tokens], dim=1)
# Generation phase: one token at a time
for _ in range(max_new_tokens - 1):
# Only process the last token (KV for previous tokens cached)
outputs = model(
next_tokens, # (batch, 1) - only new token!
use_cache=True,
cache=cache
)
next_token_logits = outputs.logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated = torch.cat([generated, next_tokens], dim=1)
# Check for EOS
if eos_token_id is not None and (next_tokens == eos_token_id).all():
break
return generated

Process the entire prompt in a single forward pass:
# Inefficient: one token at a time for prompt
for i in range(prompt_len):
outputs = model(input_ids[:, i:i+1], cache=cache)
# Efficient: entire prompt at once
outputs = model(input_ids, cache=cache)  # Parallel processing!

Why: Prompt processing is compute-bound (it can be parallelized across the prompt), while generation is memory-bound.
Minimize wasted memory by:
# Left-padding for decoder-only models
# aligns sequence ends for efficient generation
# (note: torch's pad_sequence pads on the right; either flip sequences before and
#  after padding, or use a tokenizer configured with padding_side="left")
input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

# Track actual lengths
attention_mask = (input_ids != pad_token_id)

For repeated prefixes (e.g., system prompts):
# Cache system prompt once
system_prompt_cache = cache_prompt(system_prompt)
# Reuse for all user requests
for user_input in user_requests:
cache = copy_cache(system_prompt_cache) # Start from cached state
    generate(user_input, cache=cache)

See Prefix Caching for advanced techniques.
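cache_prompt and copy_cache above are placeholders. A minimal sketch of what copy_cache could look like on top of the KVCache class from the code walkthrough: it clones the prefilled tensors so each request starts from the shared prefix without mutating it. Real serving systems typically share prefix blocks instead of copying (see PagedAttention).

import copy

def copy_cache(prefix_cache):
    """Hypothetical helper: clone a prefilled KVCache (class from the walkthrough above)."""
    new_cache = copy.copy(prefix_cache)                            # copy config fields
    new_cache.k_cache = [k.clone() for k in prefix_cache.k_cache]  # per-layer key tensors
    new_cache.v_cache = [v.clone() for v in prefix_cache.v_cache]  # per-layer value tensors
    new_cache.seq_lens = prefix_cache.seq_lens.clone()             # prefix length bookkeeping
    return new_cache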
For sequences whose KV cache exceeds GPU memory:
# Store cache on CPU, stream to GPU as needed
cache = MemoryMappedKVCache(
storage='cpu',
dtype=torch.float16
)
# Only active portion on GPU
active_window = cache.get_window(start=max(0, t-window), end=t)

Avoid allocations during generation:
# Bad: Creates new tensor every step
cache = torch.cat([cache, new_kv], dim=2)
# Good: Pre-allocated, in-place update
cache[:, :, current_pos, :] = new_kv
current_pos += 1

Setup:
- Model: LLaMA-7B (32 layers, 32 heads, head_dim=128)
- Hardware: NVIDIA A100 80GB
- Sequences: 1024 tokens (128 prompt + 896 generated)
- Batch sizes: 1, 8, 32
| Batch Size | Without Cache (ms/token) | With Cache (ms/token) | Speedup |
|---|---|---|---|
| 1 | 2847 | 28 | 101.7x |
| 8 | 22776 | 187 | 121.8x |
| 32 | OOM | 731 | ∞ |
Key findings:
- Single sequence: 100x+ speedup
- Speedup increases with longer sequences (more redundant computation)
- Enables much larger batch sizes (batch size 32 runs with the cache vs. OOM without)
| Component | Without Cache | With Cache | Overhead |
|---|---|---|---|
| Model weights | 13.5 GB | 13.5 GB | - |
| Activations | 0.5 GB | 0.5 GB | - |
| Attention | 4.2 GB | - | -4.2 GB |
| KV Cache | - | 1.0 GB | +1.0 GB |
| Total | 18.2 GB | 15.0 GB | -17.6% |
Batch size 32:
- With cache: 15 GB + 32 GB (cache) = 47 GB ✓
- Without cache: 18 GB + 134 GB (attention) = 152 GB ✗ OOM
| Batch Size | Without Cache (tokens/s) | With Cache (tokens/s) | Improvement |
|---|---|---|---|
| 1 | 0.35 | 35.7 | 102x |
| 8 | 0.35 | 286 | 817x |
| 32 | OOM | 1144 | ∞ |
Key insight: KV cache enables batching, which dramatically improves throughput.
| Sequence Length | KV Cache Size (GB) | Attention Matrix (GB) |
|---|---|---|
| 512 | 0.5 | 1.0 |
| 1024 | 1.0 | 4.2 |
| 2048 | 2.0 | 16.8 |
| 4096 | 4.0 | 67.1 |
Cache grows linearly, attention grows quadratically.
# WRONG: Cache contains data from previous sequence
for prompt in prompts:
generate(prompt, cache=shared_cache) # Contaminated!
# CORRECT: Reset cache for each sequence
for prompt in prompts:
cache.clear() # or create new cache
generate(prompt, cache=cache)

# WRONG: Position gets out of sync
cache.update(layer_idx=0, key=k, value=v, start_pos=self.pos)
self.pos += 2 # Oops! Only added 1 token
# CORRECT: Track accurately
cache.update(layer_idx=0, key=k, value=v, start_pos=self.pos)
self.pos += key.shape[2]  # Actual number of tokens added

# WRONG: Cache expects (batch, heads, seq, dim)
key = key.transpose(1, 2) # Wrong shape passed to cache
cache.update(layer_idx=0, key=key, value=value)
# CORRECT: Ensure consistent shape convention
assert key.shape == (batch, heads, seq_len, dim)
cache.update(layer_idx=0, key=key, value=value)

# WRONG: All sequences forced to same length
cache.seq_lens[:] = max_len # Shorter sequences attend to garbage
# CORRECT: Track per-sequence lengths
cache.seq_lens[batch_idx] = actual_length[batch_idx]

# WRONG: Cache keeps growing indefinitely
cache.append(new_kv) # No limit!
# CORRECT: Implement cache eviction
if cache.size() > max_size:
cache.evict_oldest()  # or use sliding window

# WRONG: Single cache shared across layers
for layer in model.layers:
output = layer(x, cache=global_cache) # Overwrites previous layer!
# CORRECT: Separate cache per layer
for layer_idx, layer in enumerate(model.layers):
output = layer(x, cache=cache, layer_idx=layer_idx)
- Attention Is All You Need (Vaswani et al., 2017). Original Transformer paper. https://arxiv.org/abs/1706.03762
- Generating Long Sequences with Sparse Transformers (Child et al., 2019). Early discussion of KV caching patterns. https://arxiv.org/abs/1904.10509
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019). Multi-Query Attention reduces KV cache size. https://arxiv.org/abs/1911.02150
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023). Grouped-Query Attention for KV cache reduction. https://arxiv.org/abs/2305.13245
- HuggingFace Transformers: transformers/cache_utils.py
- vLLM: Memory management with KV cache
- TensorRT-LLM: Optimized KV cache kernels
- PagedAttention - Virtual memory for KV caches
- Quantized KV Cache - Compress cache with quantization
- Prefix Caching - Share cache across requests
- Continuous Batching - Dynamic batching with caching
- Learn PagedAttention: Eliminate memory fragmentation → 02_paged_attention.md
- Reduce memory 2-4x: Quantize the cache → 03_quantized_kv_cache.md
- Reuse computations: Implement prefix caching → 04_prefix_caching.md