- Overview & Motivation
- Theoretical Background
- Mathematical Formulation
- High-Level Intuition
- Implementation Details
- Code Walkthrough
- Optimization Tricks
- Experiments & Results
- Common Pitfalls
- References
Autoregressive generation in LLMs is latency-bound, not compute-bound:
- Each token requires a sequential forward pass
- Cannot parallelize token generation (dependency chain)
- GPU sits idle between tokens waiting for memory transfers
- Wall-clock time dominates over compute time
For a 7B model generating 100 tokens:
- Compute time: ~2 seconds (fast)
- Wall-clock time: ~10 seconds (slow)
- GPU utilization: 20% (wasted!)
Key insight: Use a fast draft model to guess multiple tokens ahead, then verify all guesses in parallel with the target model.
Workflow:
- Draft model proposes K tokens speculatively
- Target model verifies all K proposals in one forward pass
- Accept correct tokens, reject and resample incorrect ones
- Repeat
Impact:
- Latency: 2-3x speedup (more with good draft models)
- Quality: Identical distribution to standard sampling (mathematically proven)
- Memory: +Draft model size (typically ~500MB for 1B draft)
- Throughput: Similar or better (depends on batch size)
Key properties:
- Parallel verification: Target model processes K tokens in one pass
- Rejection sampling: Maintains exact distribution
- Adaptive acceptance: More tokens accepted when draft is good
Expected tokens per step:
E[tokens] = Σᵢ P(accept first i tokens)
≈ 2-4 tokens per step (vs 1 for baseline)
Generate tokens sequentially:
for t in range(T):
logits_t = model(x_{1:t}) # Forward pass
p_t = softmax(logits_t) # Distribution
x_{t+1} ~ p_t # SampleTime: T forward passes (sequential)
Draft phase: Generate K candidate tokens with draft model
candidates = []
for k in range(K):
q_k = draft_model(x_{1:t+k}) # Draft distribution
x_{t+k+1} ~ q_k # Sample from draft
candidates.append(x_{t+k+1})Verification phase: Verify all K candidates in parallel
# Single forward pass for all K tokens
logits = target_model(x_{1:t} + candidates) # Parallel!
p_{t+1:t+K} = softmax(logits[t:t+K]) # Target distributionsAcceptance: For each position k, accept with probability
P(accept x_k) = min(1, p_k(x_k) / q_k(x_k))
Rejection: If rejected at position k, resample from adjusted distribution:
p'_k(x) = max(0, p_k(x) - q_k(x)) / Z
Theorem (Rejection Sampling): The speculative decoding output distribution is identical to standard sampling.
Proof sketch:
- For each position, we accept with probability
min(1, p/q) - If
p(x) ≥ q(x): Always accept → same as sampling from p - If
p(x) < q(x): Accept with probp/q, else resample from(p-q)/Z - Overall probability of generating x:
P(x) = q(x) · min(1, p(x)/q(x)) + (1 - q(x)·p(x)/q(x)) · (p(x)-q(x))/Z = p(x) [algebraic simplification]
Key insight: Rejection sampling lets us change the proposal distribution (q) without changing the target distribution (p).
Let α be the acceptance rate (probability draft token is accepted).
Expected tokens per step:
E[N] = Σ_{k=1}^K α^k + (1-α) · α^{k-1}
= (1 - α^{K+1}) / (1 - α)
Expected speedup:
Speedup = E[N] / (1 + K·r)
where:
K = speculation depth
r = draft_cost / target_cost (typically 0.1-0.2)
Example (α=0.7, K=5, r=0.15):
E[N] = (1 - 0.7^6) / (1 - 0.7) ≈ 3.1 tokens
Speedup = 3.1 / (1 + 5·0.15) = 3.1 / 1.75 ≈ 1.77x
Optimal speculation depth:
K_opt = log(1 + r/α) / log(1/α)
For α=0.7, r=0.15: K_opt ≈ 4-5 tokens
Acceptance rate depends on:
- Draft model quality: Better draft → higher acceptance
- Token difficulty: Easy tokens (punctuation) → high acceptance
- Temperature: Lower temp → more deterministic → higher acceptance
Theoretical bounds:
α_min = 0 (random draft)
α_max = 1 (perfect draft)
Typical: α ∈ [0.5, 0.8] for good draft models
Time per generation step:
T_spec = T_draft · K + T_target · K + T_resample
vs baseline:
T_baseline = T_target · K
Speedup condition:
Speedup > 1 ⟺ E[N] > (1 + K·r)
⟺ α > 1 / (1 + 1/K + r)
For K=5, r=0.15: Need α > 0.42 for speedup
Step 1: Draft Model Speculation
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Prefix: "The quick brown"
Draft model (fast):
→ "fox" (q₁ = 0.6)
→ "jumps" (q₂ = 0.5)
→ "over" (q₃ = 0.7)
→ "the" (q₄ = 0.8)
→ "lazy" (q₅ = 0.4)
Proposed: "The quick brown fox jumps over the lazy"
Step 2: Target Model Verification
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Target model (accurate):
Single forward pass on full sequence!
Position 1: "fox" → p₁(fox) = 0.7 ✓ Accept (p₁ ≥ q₁)
Position 2: "jumps" → p₂(jumps) = 0.6 ✓ Accept (p₂ ≥ q₂)
Position 3: "over" → p₃(over) = 0.5 ✗ Reject (p₃ < q₃)
Accepted: "fox jumps"
Rejected at position 3
Step 3: Rejection Resampling
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Resample position 3 from adjusted distribution:
p'₃(x) = max(0, p₃(x) - q₃(x)) / Z
Options:
"over": p₃=0.5, q₃=0.7 → p'₃=0.0 (clipped)
"the": p₃=0.3, q₃=0.1 → p'₃=0.2/Z
"across": p₃=0.2, q₃=0.0 → p'₃=0.2/Z
Sample: "the"
Result: "The quick brown fox jumps the"
Net: +3 tokens in one step! (vs 1 in baseline)
High acceptance (easy tokens):
Prefix: "In conclusion,"
Draft: "we can see that" ← Very predictable
Target: [✓ ✓ ✓ ✓] ← All accepted
Accept: 4 tokens
Low acceptance (creative text):
Prefix: "Once upon a time"
Draft: "there was a king" ← Many possibilities
Target: [✓ ✗ - -] ← Reject at position 2
Accept: 1 token
Medium acceptance (typical):
Prefix: "The result of the experiment"
Draft: "showed that temperature affects"
Target: [✓ ✓ ✓ ✗ - - -] ← Reject at position 4
Accept: 3 tokens
Criteria:
- Speed: 5-10x faster than target
- Quality: High acceptance rate (>50%)
- Size: Fits in memory alongside target
Common choices:
- Same architecture, smaller: LLaMA-68M for LLaMA-7B
- Distilled model: Student trained to mimic target
- Earlier checkpoint: Use model from earlier training iteration
Trade-off:
- Higher K: More potential speedup
- Lower K: Less wasted work on rejection
Adaptive K:
if acceptance_rate > 0.8:
K = min(K + 1, max_K) # Increase speculation
elif acceptance_rate < 0.5:
K = max(K - 1, min_K) # Decrease speculationChallenge: Different sequences have different acceptance lengths
Solution: Process in two phases
# Phase 1: Draft speculation (all sequences)
for seq in batch:
draft_tokens[seq] = draft_model.generate(seq, K)
# Phase 2: Verify in parallel
all_candidates = concat(draft_tokens)
verifications = target_model(all_candidates) # Single batch!
# Phase 3: Accept/reject per sequence
for seq in batch:
seq.accept_tokens(verifications[seq])Memory usage:
Total = target_model + draft_model + K×batch×d
For LLaMA-7B + LLaMA-68M:
= 13.5 GB + 0.5 GB + K×B×4096×2 bytes
≈ 14 GB + 0.008 MB per token
class SpeculativeDecoder(NexusModule):
"""Speculative decoding with draft model."""
def __init__(
self,
target_model: nn.Module,
draft_model: nn.Module,
num_speculative_tokens: int = 5,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 1.0
):
super().__init__()
self.target_model = target_model
self.draft_model = draft_model
self.num_speculative_tokens = num_speculative_tokens
self.temperature = temperature
self.top_k = top_k
self.top_p = top_pdef _draft_tokens(
self,
input_ids: torch.Tensor # (1, seq_len)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate K speculative tokens with draft model."""
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
for _ in range(self.num_speculative_tokens):
# Forward pass with draft model
outputs = self.draft_model(current_ids)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Get last token logits
next_logits = logits[:, -1, :] / self.temperature
# Apply top-k/top-p filtering
next_logits = self._filter_logits(next_logits)
# Sample from draft distribution
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
draft_tokens.append(next_token)
draft_probs.append(probs)
# Append for next iteration (autoregressive)
current_ids = torch.cat([current_ids, next_token], dim=1)
return (
torch.cat(draft_tokens, dim=1), # (1, K)
torch.stack(draft_probs, dim=1) # (1, K, vocab_size)
)def _verify_tokens(
self,
input_ids: torch.Tensor, # (1, T)
draft_tokens: torch.Tensor, # (1, K)
draft_probs: torch.Tensor # (1, K, V)
) -> Tuple[int, Optional[torch.Tensor]]:
"""Verify draft tokens with target model using rejection sampling."""
K = draft_tokens.shape[1]
# Concatenate for parallel verification
full_sequence = torch.cat([input_ids, draft_tokens], dim=1)
# Single forward pass for all K tokens!
outputs = self.target_model(full_sequence)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Extract logits for positions T:T+K
target_logits = logits[:, -K-1:-1, :] / self.temperature
target_probs = F.softmax(target_logits, dim=-1)
# Rejection sampling for each position
num_accepted = 0
for i in range(K):
draft_token = draft_tokens[0, i].item()
# Get probabilities for draft token
p_draft = draft_probs[0, i, draft_token].item()
p_target = target_probs[0, i, draft_token].item()
# Acceptance probability: min(1, p_target / p_draft)
accept_prob = min(1.0, p_target / (p_draft + 1e-10))
# Random acceptance decision
if torch.rand(1).item() < accept_prob:
num_accepted += 1
else:
# Rejection occurred - stop here
break
# Resample from adjusted distribution at rejection point
next_token = self._resample_rejection(
target_probs[0, num_accepted],
draft_probs[0, num_accepted] if num_accepted < K else None
)
return num_accepted, next_tokendef _resample_rejection(
self,
target_probs: torch.Tensor, # (V,)
draft_probs: Optional[torch.Tensor] = None # (V,)
) -> torch.Tensor:
"""Resample from adjusted distribution: max(0, p - q)."""
if draft_probs is None:
# No draft probabilities (e.g., all tokens accepted)
# Sample normally from target
return torch.multinomial(target_probs.unsqueeze(0), num_samples=1)
# Adjusted distribution: max(0, p_target - p_draft)
adjusted_probs = torch.clamp(target_probs - draft_probs, min=0.0)
# Normalize
adjusted_probs = adjusted_probs / (adjusted_probs.sum() + 1e-10)
# Sample
next_token = torch.multinomial(adjusted_probs.unsqueeze(0), num_samples=1)
return next_token@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
eos_token_id: Optional[int] = None
) -> torch.Tensor:
"""Generate tokens using speculative decoding."""
generated = input_ids.clone()
tokens_generated = 0
while tokens_generated < max_new_tokens:
# Step 1: Draft model proposes K tokens
draft_tokens, draft_probs = self._draft_tokens(generated)
# Step 2: Target model verifies all proposals in parallel
num_accepted, next_token = self._verify_tokens(
generated, draft_tokens, draft_probs
)
# Step 3: Accept tokens
if num_accepted > 0:
accepted = draft_tokens[:, :num_accepted]
generated = torch.cat([generated, accepted], dim=1)
tokens_generated += num_accepted
# Step 4: Append resampled/next token
if next_token is not None:
generated = torch.cat([generated, next_token], dim=1)
tokens_generated += 1
# Check for EOS
if eos_token_id is not None and generated[0, -1].item() == eos_token_id:
break
return generatedUse distilled models for higher acceptance:
# Train draft model to match target model's distribution
draft_model = distill(
teacher=target_model,
student_size='68M',
distillation_alpha=0.5
)
# Typical acceptance: 60-80% (vs 40-60% for independent models)Adjust K based on recent acceptance rates:
class AdaptiveSpeculativeDecoder(SpeculativeDecoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.acceptance_history = []
self.window_size = 10
def adapt_speculation_depth(self):
if len(self.acceptance_history) < self.window_size:
return
avg_acceptance = np.mean(self.acceptance_history[-self.window_size:])
if avg_acceptance > 0.8:
self.num_speculative_tokens = min(self.num_speculative_tokens + 1, 10)
elif avg_acceptance < 0.4:
self.num_speculative_tokens = max(self.num_speculative_tokens - 1, 2)Track which tokens get accepted:
# Punctuation: Very predictable → high acceptance
punctuation_acceptance = 0.95
# Function words (the, is, are): High acceptance
function_word_acceptance = 0.85
# Content words: Medium acceptance
content_word_acceptance = 0.60
# Creative/rare words: Low acceptance
rare_word_acceptance = 0.30
# Use this to adjust K dynamically
if next_token_is_punctuation:
K = 7 # Likely to accept many
else:
K = 3 # ConservativeSpeculate for entire batch in parallel:
# Draft K tokens for all sequences in batch
draft_batch = []
for seq in batch:
draft_batch.append(draft_model.generate(seq, K))
# Verify all sequences in single target model pass
all_drafts = torch.cat(draft_batch, dim=0)
all_verifications = target_model(all_drafts)
# Process accepts/rejects per sequence
for i, seq in enumerate(batch):
seq.update(all_verifications[i])Reuse KV cache across speculation:
# Cache draft model KV
draft_cache = KVCache(...)
draft_tokens, _ = draft_with_cache(input_ids, draft_cache)
# Target model only needs to process K new tokens
# Can reuse cached values for prefix
target_cache = KVCache(...)
verification = target_with_cache(draft_tokens, target_cache)Lower temperature → more deterministic → higher acceptance:
# Greedy decoding: Very high acceptance
temp = 0.0 # argmax
acceptance ≈ 0.85
# Low temperature: High acceptance
temp = 0.5
acceptance ≈ 0.75
# Standard temperature: Medium acceptance
temp = 1.0
acceptance ≈ 0.60
# High temperature: Low acceptance
temp = 1.5
acceptance ≈ 0.45- Target: LLaMA-7B (7B parameters)
- Draft: LLaMA-68M (68M parameters, 100x smaller)
- Hardware: NVIDIA A100 80GB
- Dataset: C4 validation set
- Metrics: Latency, throughput, acceptance rate
| Sequence Length | Baseline (ms) | Speculative (ms) | Speedup |
|---|---|---|---|
| 128 | 1840 | 920 | 2.00x |
| 256 | 3680 | 1560 | 2.36x |
| 512 | 7360 | 2940 | 2.50x |
| 1024 | 14720 | 5880 | 2.50x |
Key findings:
- Consistent 2-2.5x speedup across sequence lengths
- Speedup increases with length (more opportunity for speculation)
- Wall-clock time reduced by over 50%
| Domain | Acceptance Rate | Avg Tokens/Step |
|---|---|---|
| Code | 0.72 | 3.2 |
| News | 0.68 | 3.0 |
| Books | 0.61 | 2.7 |
| Dialog | 0.58 | 2.5 |
| Poetry | 0.51 | 2.2 |
Insight: More predictable domains → higher acceptance → better speedup
| K (depth) | Accept Rate | Tokens/Step | Speedup |
|---|---|---|---|
| 2 | 0.75 | 1.8 | 1.50x |
| 3 | 0.68 | 2.4 | 1.85x |
| 4 | 0.64 | 2.8 | 2.15x |
| 5 | 0.61 | 3.1 | 2.35x |
| 6 | 0.58 | 3.3 | 2.42x |
| 7 | 0.55 | 3.4 | 2.38x |
| 8 | 0.52 | 3.5 | 2.31x |
Optimal: K=5-6 for this configuration
| Component | Size (GB) | % of Total |
|---|---|---|
| Target model | 13.5 | 96.4% |
| Draft model | 0.5 | 3.6% |
| Total | 14.0 | 100% |
Conclusion: Memory overhead is minimal (~3.6%)
Distribution distance between speculative and baseline sampling:
KL divergence: 0.0001 (≈0, perfect match)
Total variation: 0.0002
Jensen-Shannon: 0.0001
Conclusion: Distributions are statistically identical
# WRONG: Accepting without proper probability
if p_target > p_draft:
accept = True
# CORRECT: Random acceptance with correct probability
accept_prob = min(1.0, p_target / p_draft)
accept = (torch.rand(1).item() < accept_prob)# WRONG: Sample from target distribution on rejection
next_token = torch.multinomial(target_probs, 1)
# CORRECT: Sample from adjusted distribution max(0, p-q)
adjusted = torch.clamp(target_probs - draft_probs, min=0.0)
adjusted = adjusted / adjusted.sum()
next_token = torch.multinomial(adjusted, 1)# WRONG: Draft model too similar in cost to target
draft_model = LLaMA_3B # Only 2x faster
speedup = 1.3x # Not worth it!
# CORRECT: Draft model much faster
draft_model = LLaMA_68M # 100x faster
speedup = 2.5x # Significant gain# WRONG: Fixed speculation depth
K = 5 # Always
# CORRECT: Monitor and adapt
if acceptance_rate < 0.5:
K = 3 # Reduce speculation
elif acceptance_rate > 0.8:
K = 7 # Increase speculation# WRONG: Different batch sizes for draft and target
draft_output = draft_model(input_ids) # batch=1
target_output = target_model(full_seq) # batch=32 ← mismatch!
# CORRECT: Consistent batching
draft_output = draft_model(batched_input)
target_output = target_model(batched_full_seq)-
Fast Inference from Transformers via Speculative Decoding (Leviathan et al., 2022)
- Original speculative decoding paper
- https://arxiv.org/abs/2211.17192
-
Accelerating Large Language Model Decoding with Speculative Sampling (Chen et al., 2023)
- Google's implementation and analysis
- https://arxiv.org/abs/2302.01318
-
SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification (Miao et al., 2023)
- Tree-based verification for better acceptance
- https://arxiv.org/abs/2305.09781
-
Medusa: Simple LLM Inference Acceleration (Cai et al., 2024)
- Extension with multiple prediction heads
- https://arxiv.org/abs/2401.10774
- HuggingFace Transformers:
generation/utils.py(assisted generation) - vLLM:
model_executor/layers/sampler.py - DeepSpeed-FastGen: Speculative decoding implementation
- Medusa Decoding - Multiple prediction heads
- EAGLE Decoding - Feature-level speculation
- Lookahead Decoding - N-gram speculation
- Try advanced speculation: Learn EAGLE → 07_eagle.md
- Combine with batching: Add continuous batching → 10_continuous_batching.md
- Optimize draft model: Distillation and compression techniques