Multi-token prediction (MTP) is a training and inference technique that enables language models to predict multiple future tokens simultaneously, achieving 2-3x speedup while improving model quality.
- Introduction and Motivation
- Theoretical Foundation
- Mathematical Formulation
- Architecture Design
- Implementation Details
- Training Strategy
- Inference Optimization
- Performance Analysis
- Integration with Serving Systems
- Benchmarks and Results
Standard autoregressive language models generate text one token at a time:
Traditional: t₀:"The" → t₁:"cat" → t₂:"sat" → t₃:"on"
Problem: Sequential dependency limits parallelism
MTP trains the model to predict multiple future tokens from each position:
Position 0: predict ["cat", "sat", "on", "the"]
Position 1: predict ["sat", "on", "the", "mat"]
Key Benefits:
- 2-3x faster inference
- Better model quality (+1-2% on benchmarks)
- No separate draft model needed
- Self-drafting mechanism
- Meta 2024: First showed MTP improves both speed and quality
- Relation to Speculative Decoding: MTP serves as self-drafting
- Training Benefit: Auxiliary objective improves representations
Standard LM: L = -E[log P(x_t | x_<t)]
Multi-token: L_MTP = -E[Σᵢ wᵢ · log P(x_{t+i} | x_<t)]
1. Richer Gradient Signal: Multiple prediction heads provide more learning signal
2. Forced Planning: To predict t+4, model must represent t+1, t+2, t+3
3. Regularization: Auxiliary objectives prevent overfitting
| Method | Draft Model | Memory | Training | Quality |
|---|---|---|---|---|
| MTP | Self-draft | +2% | Required | +1.5% |
| Speculative | Separate | +50% | None | 0% |
| Medusa | Self-draft | +1% | Fine-tune | 0% |
Base model: h_t = Transformer(x₁,...,xₜ; θ)
Prediction heads: For i ∈ [1,n]: P(x_{t+i}|x_<t) = Softmax(Head_i(h_t))
L_total = Σₜ Σᵢ wᵢ · CrossEntropy(Head_i(h_t), x_{t+i})
Common weight schedules:
- Uniform:
wᵢ = 1/n - Inverse distance:
wᵢ = 1/i - Exponential:
wᵢ = α^(i-1)
for t in range(max_length):
# Generate candidates
h_t = Model(x_≤t)
candidates = [Head_i(h_t) for i in range(n)]
# Verify with model
full_logits = Model(x_≤t ⊕ candidates)
accepted = VerifyTokens(full_logits, candidates)
# Accept or fallback
if accepted:
x_≤t = x_≤t ⊕ accepted
else:
x_{t+1} = Sample(full_logits[t])Accept token ŷ_{t+i} if: P_model(ŷ_{t+i} | x_≤{t+i-1}) > τ
Typical threshold: τ = 0.3-0.6
class IndependentHeads(nn.Module):
def __init__(self, dim, vocab_size, num_heads):
self.heads = nn.ModuleList([
nn.Linear(dim, vocab_size)
for _ in range(num_heads)
])
def forward(self, h):
return [head(h) for head in self.heads]Pros: Fast, minimal parameters (~0.1% overhead) Cons: Limited expressiveness
class SharedTrunkHeads(nn.Module):
def __init__(self, dim, vocab_size, num_heads, trunk_dim):
self.trunk = nn.Linear(dim, trunk_dim)
self.heads = nn.ModuleList([
nn.Linear(trunk_dim, vocab_size)
for _ in range(num_heads)
])
def forward(self, h):
trunk_out = self.trunk(h)
return [head(trunk_out) for head in self.heads]Pros: Better features, fewer parameters Cons: Slight compute overhead
class MLPHeads(nn.Module):
def __init__(self, dim, vocab_size, num_heads):
self.heads = nn.ModuleList()
for _ in range(num_heads):
self.heads.append(nn.Sequential(
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, vocab_size)
))Pros: Higher capacity Cons: More parameters, slower
For 7B model with 4 heads:
- Independent: 524M params (7.5%)
- Shared trunk: 135M params (1.9%)
From /nexus/components/inference/multi_token.py:
class MultiTokenPredictionHead(NexusModule):
def __init__(
self,
dim: int,
vocab_size: int,
num_future_tokens: int = 4,
shared_trunk: bool = False,
trunk_dim: Optional[int] = None
):
super().__init__()
if shared_trunk:
trunk_dim = trunk_dim or dim
self.trunk = nn.Linear(dim, trunk_dim)
head_input_dim = trunk_dim
else:
self.trunk = None
head_input_dim = dim
self.heads = nn.ModuleList([
nn.Linear(head_input_dim, vocab_size)
for _ in range(num_future_tokens)
])
def forward(self, hidden_states, return_all=True):
if self.trunk is not None:
hidden_states = self.trunk(hidden_states)
if return_all:
all_logits = [head(hidden_states) for head in self.heads]
return torch.stack(all_logits, dim=2)
else:
return self.heads[0](hidden_states)def compute_loss(self, hidden_states, labels, weights=None):
batch_size, seq_len, _ = hidden_states.shape
if weights is None:
weights = [1.0 / (i + 1) for i in range(self.num_future_tokens)]
total_loss = 0.0
total_weight = sum(weights)
for i, (head, weight) in enumerate(zip(self.heads, weights)):
if i >= seq_len - 1:
break
logits = head(hidden_states[:, :-i-1])
targets = labels[:, i+1:]
loss = F.cross_entropy(
logits.view(-1, self.vocab_size),
targets.reshape(-1),
ignore_index=-100
)
total_loss += weight * loss
return total_loss / total_weight@torch.no_grad()
def generate_with_mtp(
model, mtp_heads, input_ids,
max_new_tokens=100,
acceptance_threshold=0.5
):
generated = input_ids.clone()
tokens_generated = 0
while tokens_generated < max_new_tokens:
# Get hidden states
output = model(generated, output_hidden_states=True)
hidden = output.hidden_states[-1][:, -1:, :]
# Generate candidates
candidates = []
for head in mtp_heads.heads:
logits = head(hidden)
token = torch.argmax(logits, dim=-1)
candidates.append(token)
# Verify candidates
candidate_seq = torch.cat([generated, torch.stack(candidates, dim=1).squeeze(2)], dim=1)
verify_out = model(candidate_seq)
verify_logits = verify_out.logits
# Check acceptance
num_accepted = 0
prefix_len = generated.shape[1]
for i, cand in enumerate(candidates):
pos_logits = verify_logits[:, prefix_len + i - 1, :]
probs = F.softmax(pos_logits, dim=-1)
if probs[0, cand.item()] >= acceptance_threshold:
num_accepted += 1
else:
break
# Accept or fallback
if num_accepted > 0:
accepted = torch.stack(candidates[:num_accepted], dim=1).squeeze(2)
generated = torch.cat([generated, accepted], dim=1)
tokens_generated += num_accepted
else:
next_token = torch.argmax(verify_logits[:, -1], dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
tokens_generated += 1
return generateddef train_step(model, mtp_heads, batch, optimizer):
input_ids = batch['input_ids']
labels = batch['labels']
# Forward pass
outputs = model(input_ids, output_hidden_states=True)
hidden = outputs.hidden_states[-1]
# Standard LM loss
lm_loss = F.cross_entropy(
outputs.logits.view(-1, vocab_size),
labels.view(-1),
ignore_index=-100
)
# MTP loss
mtp_loss = mtp_heads.compute_loss(hidden, labels, weights=[1.0, 0.5, 0.25, 0.125])
# Combined
total_loss = lm_loss + 0.5 * mtp_loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()
return lm_loss.item(), mtp_loss.item()Two-phase approach:
Phase 1: Warmup (freeze base model)
model.eval()
for param in model.parameters():
param.requires_grad = False
optimizer = AdamW(mtp_heads.parameters(), lr=5e-4)
for step in range(warmup_steps):
with torch.no_grad():
hidden = model(**batch, output_hidden_states=True).hidden_states[-1]
loss = mtp_heads.compute_loss(hidden, batch['labels'])
loss.backward()
optimizer.step()
optimizer.zero_grad()Phase 2: Joint training
model.train()
for param in model.parameters():
param.requires_grad = True
optimizer = AdamW(
list(model.parameters()) + list(mtp_heads.parameters()),
lr=1e-5
)
for epoch in range(num_epochs):
for batch in train_data:
outputs = model(**batch, output_hidden_states=True)
lm_loss = outputs.loss
mtp_loss = mtp_heads.compute_loss(outputs.hidden_states[-1], batch['labels'])
total_loss = lm_loss + 0.3 * mtp_loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()Key parameters:
PARAMS = {
'num_future_tokens': [2, 4, 8],
'mtp_loss_weight': [0.1, 0.3, 0.5],
'position_weights': [[1.0, 0.5, 0.25, 0.125], [1.0, 1.0, 1.0, 1.0]],
'acceptance_threshold': [0.3, 0.5, 0.7],
}Guidelines:
- Start with 4 heads, 0.3 MTP weight
- Monitor validation perplexity (shouldn't degrade)
- Target 60-80% acceptance rate
- More heads = more speedup but harder training
class OptimizedMTPGenerator:
def __init__(self, model, mtp_heads, threshold=0.5, max_spec=4):
self.model = model
self.mtp_heads = mtp_heads
self.threshold = threshold
self.max_spec = max_spec
self.candidate_buffer = None
@torch.no_grad()
def generate(self, input_ids, max_new_tokens=100):
generated = input_ids.clone()
stats = {'total': 0, 'accepted': 0, 'forward_passes': 0}
kv_cache = None
while stats['total'] < max_new_tokens:
# Forward with caching
outputs = self.model(
generated,
past_key_values=kv_cache,
use_cache=True,
output_hidden_states=True
)
hidden = outputs.hidden_states[-1][:, -1:, :]
kv_cache = outputs.past_key_values
stats['forward_passes'] += 1
# Generate candidates
if self.candidate_buffer is None:
self.candidate_buffer = torch.zeros(
1, self.max_spec,
dtype=torch.long,
device=input_ids.device
)
for i, head in enumerate(self.mtp_heads.heads[:self.max_spec]):
logits = head(hidden)
self.candidate_buffer[0, i] = torch.argmax(logits[0, -1])
# Verify
verify_input = torch.cat([generated, self.candidate_buffer], dim=1)
verify_out = self.model(verify_input)
verify_logits = verify_out.logits
stats['forward_passes'] += 1
# Accept
num_accepted = 0
prefix_len = generated.shape[1]
for i in range(self.max_spec):
cand = self.candidate_buffer[0, i].item()
pos_logits = verify_logits[0, prefix_len + i - 1, :]
probs = F.softmax(pos_logits, dim=-1)
if probs[cand] >= self.threshold:
num_accepted += 1
else:
break
if num_accepted > 0:
accepted = self.candidate_buffer[:, :num_accepted]
generated = torch.cat([generated, accepted], dim=1)
stats['accepted'] += num_accepted
stats['total'] += num_accepted
else:
next_token = torch.argmax(verify_logits[0, -1]).unsqueeze(0).unsqueeze(0)
generated = torch.cat([generated, next_token], dim=1)
stats['total'] += 1
kv_cache = None # Clear for next iteration
stats['acceptance_rate'] = stats['accepted'] / stats['total']
return generated, stats@torch.jit.script
def fused_mtp_heads(
hidden: torch.Tensor,
head_weights: torch.Tensor
) -> torch.Tensor:
"""Fused computation of all heads"""
batch, seq, dim = hidden.shape
num_heads, vocab, _ = head_weights.shape
hidden_flat = hidden.view(-1, dim)
logits = torch.matmul(hidden_flat, head_weights.transpose(1, 2))
return logits.view(batch, seq, num_heads, vocab)class MemoryEfficientMTP:
def __init__(self, model, mtp_heads, use_checkpointing=True):
self.model = model
self.mtp_heads = mtp_heads
self.use_checkpointing = use_checkpointing
@torch.no_grad()
def generate(self, input_ids, max_new_tokens):
generated = input_ids
for _ in range(max_new_tokens):
if self.use_checkpointing:
candidates = torch.utils.checkpoint.checkpoint(
self._forward_heads, generated
)
else:
candidates = self._forward_heads(generated)
num_accepted = self._verify(generated, candidates)
if num_accepted > 0:
generated = torch.cat([generated, candidates[:num_accepted]], dim=1)
else:
logits = self.model(generated).logits
next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
return generatedSpeedup = E[accepted_tokens] / (1 + overhead)
≈ n · α^n / (1 + β)
where:
n = number of heads
α = acceptance rate per token
β = verification overhead
Example calculations:
def expected_speedup(acceptance_rate, num_heads, overhead=0.1):
if acceptance_rate >= 1.0:
expected = num_heads
else:
expected = acceptance_rate * (1 - acceptance_rate**num_heads) / (1 - acceptance_rate)
return expected / (1 + overhead)
# 70% acceptance, 4 heads, 10% overhead → 2.1x speedupModel Size Acceptance Rate
70B 75-85%
13B 68-75%
7B 60-70%
1B 45-55%
Task Acceptance
Code generation 70-80%
Math problems 65-75%
Creative writing 55-65%
Translation 70-80%
Standard (per token):
Forward pass: 100ms
Sampling: 1ms
Total: 101ms/token
MTP (per step):
Forward + hidden: 105ms
Head computation: 2ms
Verification: 100ms
Acceptance check: 1ms
Total: 208ms/step
With 3 tokens accepted:
208ms / 3 = 69ms/token → 1.46x speedup
With 4 tokens accepted:
208ms / 4 = 52ms/token → 1.94x speedup
Model: Llama-2-7B (14,336 MB)
Config Additional Total
4 heads, no trunk 524 MB (3.7%) 14,860 MB
4 heads, 1024 trunk 135 MB (0.9%) 14,471 MB
from vllm import LLM, SamplingParams
class MTLLMEngine(LLM):
def __init__(self, model_name, mtp_checkpoint, **kwargs):
super().__init__(model_name, **kwargs)
self.mtp_heads = MultiTokenPredictionHead.from_pretrained(mtp_checkpoint)
self.mtp_heads.eval().to(self.device)
def generate_with_mtp(self, prompts, sampling_params, threshold=0.5):
outputs = []
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt)
generated = self._generate_single_mtp(input_ids, sampling_params, threshold)
outputs.append(generated)
return outputsimport tensorrt_llm
from tensorrt_llm.runtime import GenerationSession
class MTTensorRTEngine:
def __init__(self, engine_dir, mtp_heads_path, max_batch_size=8):
self.session = GenerationSession.from_engine(engine_dir)
self.mtp_heads = self._compile_mtp_heads(mtp_heads_path)
self.max_batch_size = max_batch_size
def _compile_mtp_heads(self, path):
import torch_tensorrt
heads = torch.load(path).eval()
compiled = torch_tensorrt.compile(
heads,
inputs=[torch_tensorrt.Input(
shape=(self.max_batch_size, 1, 4096),
dtype=torch.float16
)],
enabled_precisions={torch.float16}
)
return compiledimport deepspeed
class MTDeepSpeedEngine:
def __init__(self, model_name, mtp_heads, tensor_parallel=1):
self.model = deepspeed.init_inference(
model_name,
mp_size=tensor_parallel,
dtype=torch.float16
)
self.mtp_heads = self._shard_mtp_heads(mtp_heads, tensor_parallel)
def _shard_mtp_heads(self, heads, tp_size):
if tp_size == 1:
return heads
for head in heads.heads:
head.weight = deepspeed.utils.split_tensor_along_dim(
head.weight, dim=0, num_splits=tp_size
)
return headsSingle-sequence (batch_size=1), 512 tokens:
Llama-2-7B:
Method Latency Speedup
Standard 51.2s 1.00x
MTP (65% acc) 27.8s 1.84x
MTP (70% acc) 23.5s 2.18x
MTP (75% acc) 20.1s 2.55x
Llama-2-13B:
Standard 88.4s 1.00x
MTP (68% acc) 44.2s 2.00x
Llama-2-70B:
Standard 412s 1.00x
MTP (78% acc) 156s 2.64x
Batched (batch_size=32):
Llama-2-7B:
Method Throughput Speedup
Standard 3,200 tok/s 1.00x
MTP (65% acc) 5,440 tok/s 1.70x
MTP (70% acc) 6,240 tok/s 1.95x
MTP (75% acc) 7,040 tok/s 2.20x
HumanEval (code, pass@1):
Base Model: 26.8%
MTP Inference: 26.8% (no degradation)
MTP Trained: 28.1% (+1.3% improvement)
MMLU (accuracy):
Base Model: 45.2%
MTP Inference: 45.2%
MTP Trained: 46.7% (+1.5% improvement)
Key Finding: Training with MTP improves quality on downstream tasks!
Position-wise acceptance rates:
Position 1: 76% ████████████████
Position 2: 68% ██████████████
Position 3: 55% ███████████
Position 4: 38% ████████
Average: 2.37 tokens accepted per step
Training MTP from scratch (Llama-2-7B):
Standard: 14 days (8x A100)
MTP: 16 days (8x A100)
Overhead: +14%
Fine-tuning:
Warmup: 6 hours
Joint: 24 hours
Total: 30 hours
Production (1M requests/day, 512 tokens each):
Standard:
- GPU hours: 16,000 hrs/day
- Cost: $32,000/day
MTP (2x speedup):
- GPU hours: 8,000 hrs/day
- Cost: $16,000/day
- Savings: $480k/month
ROI: Pays for training in <1 week
Method Speedup Memory Training Quality
Standard 1.00x 1.00x N/A Baseline
MTP 2.20x 1.02x Required +1.5%
Speculative 2.50x 1.50x None 0%
Medusa 2.30x 1.01x FT only 0%
EAGLE 3.00x 1.01x FT only 0%
Lookahead 1.80x 1.10x None 0%
Use MTP when: ✅ Training/fine-tuning anyway ✅ Quality improvement valuable ✅ Tight memory budget ✅ 2-3x speedup sufficient
Don't use MTP when: ❌ Cannot retrain model ❌ Need max speedup (use speculative) ❌ Very tight latency requirements ❌ Model < 1B (low acceptance)
# 7B models
CONFIG_7B = {
'num_heads': 4,
'shared_trunk': True,
'trunk_dim': 1024,
'acceptance_threshold': 0.5,
}
# 13B+ models
CONFIG_13B = {
'num_heads': 6,
'shared_trunk': True,
'trunk_dim': 2048,
'acceptance_threshold': 0.4,
}
# 70B+ models
CONFIG_70B = {
'num_heads': 8,
'shared_trunk': True,
'trunk_dim': 4096,
'acceptance_threshold': 0.3,
}Multi-token prediction offers compelling benefits:
- Training: +1-2% quality improvement
- Inference: 2-3x speedup
- Simplicity: No draft model needed
- Flexibility: Combines with other optimizations
Best used when training/fine-tuning and when quality improvements justify training cost.
- MTP + Continuous Batching: 10-15x throughput
- MTP + Quantization: Maintain speedup with less memory
- MTP + Prefix Caching: Stack speedups
Papers:
Code:
- Nexus:
/nexus/components/inference/multi_token.py - Examples:
/examples/training/multi_token_training.py - Benchmarks:
/benchmarks/inference/mtp_benchmark.py