RWKV-6 (codenamed "Finch") is the sixth iteration of the RWKV architecture, combining the O(1) per-token inference cost of RNNs with the parallelizable training of transformers. It introduces several innovations that make its quality competitive with transformers while retaining RNN-level efficiency.
| Aspect | Transformer | RWKV-4/5 | RWKV-6 (Finch) |
|---|---|---|---|
| Training complexity | O(n²) | O(n) | O(n) |
| Inference cost per token | O(n) with KV cache | O(1) | O(1) |
| State size | O(n×d) KV cache | O(d) vector | O(d²) matrix |
| Long context | Expensive | Limited | Excellent |
| Expressivity | High | Moderate | High |
| Memory efficiency | Poor | Excellent | Excellent |
RWKV-6 achieves transformer-level performance while maintaining constant-time inference, making it ideal for long-context applications and edge deployment.
- Matrix-Valued States: Unlike vector states in earlier versions, RWKV-6 uses matrix-valued recurrent states that store key-value associations
- Data-Dependent Decay: The decay factor adapts to input, enabling adaptive forgetting
- Token Shift Mechanism: Weighted mixing of current and previous tokens provides local context
- WKV Algorithm: Efficient weighted key-value recurrence with O(1) per-step complexity
RWKV-6 maintains a matrix-valued state h ∈ ℝ^(d×d) per head (with d the head dimension), storing key-value associations:
State update:
h[t] = decay[t] ⊙ h[t-1] + k[t] ⊗ v[t]
where:
- decay[t]: element-wise decay factor (data-dependent)
- k[t], v[t]: key and value vectors
- ⊗: outer product
- ⊙: element-wise multiplication
This differs from vector states (RWKV-4/5) which only maintain h ∈ ℝ^d.
The decay is computed from input:
w_base: learnable base decay (per dimension)
w_dynamic[t] = x[t] @ W_w
decay[t] = exp(-(w_base + w_dynamic[t]))
This allows:
- Adaptive forgetting: Decay based on content, not just time
- Context switching: Strong decay when topic changes
- Memory retention: Weak decay for consistent context
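As a quick sketch of this computation (illustrative tensor names and shapes, using the softplus form from the full formulation below so the decay stays in (0, 1)):

import torch
import torch.nn.functional as F

d_model = 8
x_t = torch.randn(1, d_model)              # token-shifted input for the decay branch
W_w = torch.randn(d_model, d_model) * 0.1  # dynamic decay projection (illustrative weights)
w_base = torch.zeros(d_model)              # learnable per-dimension base decay

w_dynamic = x_t @ W_w
decay_t = torch.exp(-F.softplus(w_base + w_dynamic))
# decay_t lies in (0, 1): values near 1 retain the state, values near 0 forget it quickly
print(decay_t.shape)  # torch.Size([1, 8])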
Each component (R, K, V, W, G) uses token shift for local context:
x_shifted[t] = mix * x[t-1] + (1 - mix) * x[t]
where mix is learnable per channel:
mix_r[d] for receptance
mix_k[d] for key
mix_v[d] for value
mix_w[d] for decay
mix_g[d] for gate
This provides:
- Position-independent local context
- No separate positional encoding needed
- Parameter-efficient (just mix coefficients)
The core WKV (Weighted Key-Value) recurrence:
# Initialize
h = 0 (matrix: head_dim × head_dim)
# For each timestep
for t in 1..T:
# Apply decay
h = decay[t] * h
# Add key-value pair
h = h + k[t] ⊗ v[t]
# Retrieve with receptance
output[t] = r[t] @ h + sum(bonus ⊙ k[t] ⊙ r[t]) * v[t]
The bonus term allows direct current-token attention.
Complete time mixing formulation:
Given input x[t]:
# Token shift
x_r, x_k, x_v, x_g, x_w = TokenShift(x, num_shifts=5)
# Projections
r[t] = x_r[t] @ W_r (receptance)
k[t] = x_k[t] @ W_k (key)
v[t] = x_v[t] @ W_v (value)
g[t] = silu(x_g[t] @ W_g) (gate)
# Data-dependent decay
w[t] = w_base + x_w[t] @ W_w
decay[t] = exp(-softplus(w[t]))
# WKV recurrence (per head h)
h^h[t] = decay^h[t] ⊙ h^h[t-1] + k^h[t] ⊗ v^h[t]
o^h[t] = r^h[t] @ h^h[t]
# Bonus term (current token attention)
o^h[t] = o^h[t] + (Σ_d bonus^h[d] · k^h[t,d] · r^h[t,d]) · v^h[t]
# Combine heads and apply gate
output[t] = (GroupNorm(concat(o^1[t], ..., o^H[t])) ⊙ g[t]) @ W_o
The channel mixing layer replaces the standard FFN:
# Token shift (2 shifts for K and R)
x_k, x_r = TokenShift(x, num_shifts=2)
# Channel mixing
k = x_k @ W_k
r = sigmoid(x_r @ W_r)
# Squared ReLU activation
k_activated = ReLU(k)²
# Value projection and gating
output = r ⊙ (k_activated @ W_v)
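A minimal module sketch of this channel-mixing layer follows. It is an illustration of the equations above, not necessarily the exact RWKV6ChannelMixing implementation; ChannelMixingSketch is a name introduced here:

import torch
import torch.nn as nn

class ChannelMixingSketch(nn.Module):
    def __init__(self, d_model, ffn_expand=4):
        super().__init__()
        hidden = d_model * ffn_expand
        # token-shift mix ratios for the K and R branches
        self.mix_k = nn.Parameter(torch.zeros(d_model))
        self.mix_r = nn.Parameter(torch.zeros(d_model))
        self.k_proj = nn.Linear(d_model, hidden, bias=False)
        self.r_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x, last_x=None):
        # previous-token tensor (zeros before the first position)
        if last_x is None:
            last_x = torch.zeros_like(x[:, :1])
        x_prev = torch.cat([last_x, x[:, :-1]], dim=1)
        mk, mr = torch.sigmoid(self.mix_k), torch.sigmoid(self.mix_r)
        x_k = x * (1 - mk) + x_prev * mk
        x_r = x * (1 - mr) + x_prev * mr
        k = torch.relu(self.k_proj(x_k)) ** 2      # squared ReLU activation
        r = torch.sigmoid(self.r_proj(x_r))        # receptance gate
        return r * self.v_proj(k), x[:, -1:]       # gated output, last token for the next chunk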
# Time mixing
residual = x
x = LayerNorm(x)
x, state_tm = TimeMixing(x, state_tm)
x = x + residual
# Channel mixing
residual = x
x = LayerNorm(x)
x = residual + ChannelMixing(x)
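Assembled, a block looks roughly like the sketch below. It reuses the RWKV6TimeMixing module implemented later in this document and the ChannelMixingSketch above; the state-dictionary layout is an assumption for illustration, not the exact RWKV6Block interface:

import torch.nn as nn

class BlockSketch(nn.Module):
    def __init__(self, d_model, num_heads=8, layer_id=0, ffn_expand=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.time_mixing = RWKV6TimeMixing(d_model, num_heads, layer_id)
        self.channel_mixing = ChannelMixingSketch(d_model, ffn_expand)

    def forward(self, x, state=None):
        tm_state = state['tm'] if state is not None else None
        cm_last = state['cm'] if state is not None else None
        # time mixing: pre-norm, then residual
        h, tm_state = self.time_mixing(self.ln1(x), tm_state)
        x = x + h
        # channel mixing: pre-norm, then residual
        h, cm_last = self.channel_mixing(self.ln2(x), cm_last)
        x = x + h
        return x, {'tm': tm_state, 'cm': cm_last}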
Think of the matrix state as an associative map:
h[d1, d2] stores the association between:
- Key dimension d1
- Value dimension d2
When we read with the receptance r, we get:
output = Σ_d1 r[d1] * h[d1, :]
This is like a soft key-value store where:
- Keys determine which rows are written
- Values are stored along the columns
- Receptance selects which rows to read
- Decay manages memory capacity
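A toy read/write illustration of this view (plain tensors, no learned parameters):

import torch

head_dim = 4
h = torch.zeros(head_dim, head_dim)   # rows index key dims, columns index value dims

# Write: store the association k -> v as an outer product
k = torch.tensor([1.0, 0.0, 0.0, 0.0])
v = torch.tensor([0.0, 2.0, 0.0, 0.0])
h = h + torch.outer(k, v)

# Read: the receptance selects rows; aligned with k, it recovers v
r = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(r @ h)   # tensor([0., 2., 0., 0.])

# Decay shrinks existing rows before each new write, bounding what the state can hold
decay = torch.full((head_dim,), 0.9)
h = decay.unsqueeze(-1) * h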
Token shift creates temporal dependencies without explicit recurrence in the projection step:
Current input: "cat"
Shifted input: 0.7*"cat" + 0.3*"the"
This means K, V computations see both current and previous tokens.
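In tensor form, with a mix coefficient of 0.3, this example looks like (illustrative embeddings):

import torch

d = 4
x_the = torch.randn(d)   # embedding of the previous token "the"
x_cat = torch.randn(d)   # embedding of the current token "cat"
mix = torch.full((d,), 0.3)

# x_shifted[t] = mix * x[t-1] + (1 - mix) * x[t]
x_shifted = mix * x_the + (1 - mix) * x_cat   # = 0.3*"the" + 0.7*"cat"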
RWKV-6 succeeds by balancing three properties:
- Expressivity: Matrix states can represent complex associations
- Efficiency: O(1) updates and queries regardless of sequence length
- Stability: Exponential decay prevents state explosion
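One way to see the efficiency property concretely is to check that the recurrent state keeps the same shape no matter how many tokens have been processed; this sketch assumes the RWKV6TimeMixing interface shown later in this document (a state dict with a 'wkv' entry):

import torch
from nexus.components.ssm import RWKV6TimeMixing

tm = RWKV6TimeMixing(d_model=512, num_heads=8)
state = None
for seq_len in [1, 16, 256]:
    x = torch.randn(1, seq_len, 512)
    _, state = tm(x, state)
    # the WKV state stays (batch, heads, head_dim, head_dim) regardless of tokens seen
    print(seq_len, state['wkv'].shape)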
from nexus.components.ssm import (
RWKV6TimeMixing,
RWKV6ChannelMixing,
RWKV6Block,
RWKV6Model
)
# Single block
block = RWKV6Block(
d_model=512,
num_heads=8,
layer_id=0, # For layer-wise initialization
ffn_expand=4
)
x = torch.randn(2, 100, 512)
output, state = block(x)

class TokenShift(nn.Module):
def __init__(self, d_model, num_shifts=5):
        super().__init__()
        self.num_shifts = num_shifts
# Learnable mix ratios for each shifted output
self.mix = nn.Parameter(torch.zeros(num_shifts, d_model))
def forward(self, x, last_x=None):
batch, seq_len, d = x.shape
if last_x is None:
last_x = torch.zeros(batch, d, device=x.device)
# Shift: prepend last_x, remove last token
x_prev = torch.cat([last_x.unsqueeze(1), x[:, :-1]], dim=1)
# Apply learned mixing
shifted = []
for i in range(self.num_shifts):
mix_i = torch.sigmoid(self.mix[i])
shifted_i = x * (1 - mix_i) + x_prev * mix_i
shifted.append(shifted_i)
        return shifted, x[:, -1]

class RWKV6TimeMixing(nn.Module):
def __init__(self, d_model, num_heads=8, layer_id=0):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# Token shift
self.token_shift = TokenShift(d_model, num_shifts=5)
# Projections
self.r_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.g_proj = nn.Linear(d_model, d_model, bias=False)
# Data-dependent decay
self.w_proj = nn.Linear(d_model, d_model, bias=False)
self.w_base = nn.Parameter(torch.randn(d_model) * 0.1)
# Bonus term
self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_dim))
# Normalization and output
self.group_norm = nn.GroupNorm(num_heads, d_model)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, state=None):
batch, seq_len, _ = x.shape
# Initialize state
if state is None:
state = {
'wkv': torch.zeros(
batch, self.num_heads,
self.head_dim, self.head_dim,
device=x.device, dtype=x.dtype
),
'last_x': None
}
# Token shift
shifted, last_token = self.token_shift(x, state.get('last_x'))
x_r, x_k, x_v, x_g, x_w = shifted
# Compute R, K, V, G
r = self.r_proj(x_r)
k = self.k_proj(x_k)
v = self.v_proj(x_v)
g = F.silu(self.g_proj(x_g))
# Compute decay
w = self.w_base + self.w_proj(x_w)
w = -F.softplus(w) # Negative for decay
# Reshape to multi-head
r = r.view(batch, seq_len, self.num_heads, self.head_dim)
k = k.view(batch, seq_len, self.num_heads, self.head_dim)
v = v.view(batch, seq_len, self.num_heads, self.head_dim)
w = w.view(batch, seq_len, self.num_heads, self.head_dim)
# WKV recurrence
output, wkv_state = self._wkv_forward(r, k, v, w, state['wkv'])
# Reshape, normalize, gate
output = output.view(batch, seq_len, self.d_model)
output = self.group_norm(output.transpose(1,2)).transpose(1,2)
output = output * g
output = self.out_proj(output)
# Update state
new_state = {'wkv': wkv_state, 'last_x': last_token}
return output, new_state
def _wkv_forward(self, r, k, v, w, state):
"""WKV recurrence."""
batch, seq_len, num_heads, head_dim = r.shape
outputs = []
for t in range(seq_len):
# Apply decay
decay = torch.exp(w[:, t])
state = state * decay.unsqueeze(-1)
# Add key-value
kv = torch.einsum('bhd,bhe->bhde', k[:, t], v[:, t])
state = state + kv
# Retrieve
output_t = torch.einsum('bhd,bhde->bhe', r[:, t], state)
# Bonus term
bonus_score = torch.einsum(
'bhd,bhd->bh',
k[:, t] * self.bonus.unsqueeze(0),
r[:, t]
)
output_t = output_t + bonus_score.unsqueeze(-1) * v[:, t]
outputs.append(output_t)
        return torch.stack(outputs, dim=1), state

import torch
from nexus.components.ssm import RWKV6Block
# Create RWKV-6 block
block = RWKV6Block(
d_model=512,
num_heads=8,
layer_id=0
)
# Training: process full sequence
x = torch.randn(4, 100, 512)
output, state = block(x)
print(f"Output: {output.shape}") # (4, 100, 512)
# Inference: autoregressive generation
state = None
for t in range(50):
x_t = torch.randn(1, 1, 512)
output_t, state = block(x_t, state)
    print(f"Step {t}: generated token")

from nexus.components.ssm import RWKV6Model
# Create RWKV-6 language model
model = RWKV6Model(
d_model=768,
num_layers=12,
vocab_size=50000,
num_heads=12,
ffn_expand=4
)
# Forward pass
input_ids = torch.randint(0, 50000, (2, 512))
logits, states = model(input_ids)
print(f"Logits: {logits.shape}") # (2, 512, 50000)
# Generation with state
state = None
generated = []
for _ in range(100):
logits, state = model(input_ids[:, -1:], state)
next_token = logits.argmax(dim=-1)
generated.append(next_token)
    input_ids = next_token

def process_long_document(model, doc_tokens, chunk_size=2048):
"""Process document longer than typical context window."""
state = None
outputs = []
# Process in chunks, maintaining state
for i in range(0, len(doc_tokens), chunk_size):
chunk = doc_tokens[i:i+chunk_size].unsqueeze(0)
# Process chunk with accumulated state
output, state = model(chunk, state)
outputs.append(output)
# Concatenate all outputs
return torch.cat(outputs, dim=1), state
# Usage
model = RWKV6Model(d_model=512, num_layers=12, vocab_size=50000)
long_doc = torch.randint(0, 50000, (100000,)) # 100K tokens
output, final_state = process_long_document(model, long_doc)

import math

class MultiScaleRWKV6(nn.Module):
"""RWKV-6 with multiple time scales."""
def __init__(self, d_model, num_scales=3):
super().__init__()
self.scales = nn.ModuleList([
RWKV6Block(d_model, layer_id=i)
for i in range(num_scales)
])
# Initialize different decay rates per scale
for i, block in enumerate(self.scales):
# Slower decay for higher scales
decay_scale = 0.3 + 0.6 * i / (num_scales - 1)
block.time_mixing.w_base.data.fill_(
-math.log(decay_scale)
)
def forward(self, x, states=None):
if states is None:
states = [None] * len(self.scales)
outputs = []
new_states = []
for block, state in zip(self.scales, states):
out, new_state = block(x, state)
outputs.append(out)
new_states.append(new_state)
# Average outputs from different scales
output = sum(outputs) / len(outputs)
        return output, new_states

Evaluated on standard benchmarks:
| Model | Params | WikiText PPL | Pile Loss | Training Speed |
|---|---|---|---|---|
| GPT-3 | 125M | 20.5 | 2.12 | 1.0x |
| RWKV-4 | 169M | 21.8 | 2.24 | 1.2x |
| RWKV-5 | 169M | 20.9 | 2.16 | 1.3x |
| RWKV-6 | 169M | 19.7 | 2.08 | 1.4x |
RWKV-6 matches or exceeds transformer performance with faster training.
Performance on sequences up to 100K tokens:
| Model | 1K | 10K | 100K | Memory (GB) |
|---|---|---|---|---|
| Transformer | 95.2 | OOM | OOM | >80 |
| Longformer | 94.8 | 92.1 | OOM | >120 |
| RWKV-4 | 93.5 | 91.2 | 88.7 | 8 |
| RWKV-6 | 94.9 | 93.4 | 91.8 | 10 |
RWKV-6's matrix states provide better long-range modeling than vector states.
Tokens/second for generation (A100, batch=1):
| Context Length | Transformer | RWKV-4 | RWKV-6 |
|---|---|---|---|
| 512 | 2100 | 3800 | 3500 |
| 2048 | 980 | 3750 | 3450 |
| 8192 | 280 | 3720 | 3420 |
| 32768 | OOM | 3680 | 3400 |
RWKV-6 maintains near-constant throughput across all context lengths.
State size per layer (d_model=512, num_heads=8, head_dim=64):
| Component | Size |
|---|---|
| WKV state matrix | 8 × 64 × 64 × 4 bytes = 128 KB |
| Last token cache | 512 × 4 bytes = 2 KB |
| Total per layer | ~130 KB |
For 12 layers: ~1.5 MB vs >100 MB for transformer KV cache at 10K tokens.
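The arithmetic behind these numbers, as a quick check (fp32 throughout; the transformer figure assumes an fp32 K and V cache at the same width, which is an assumption rather than a measured value):

d_model, num_heads, num_layers = 512, 8, 12
head_dim = d_model // num_heads
bytes_fp32 = 4

wkv_state = num_heads * head_dim * head_dim * bytes_fp32    # 131,072 B = 128 KB
last_token = d_model * bytes_fp32                           # 2,048 B = 2 KB
per_layer = wkv_state + last_token                          # ~130 KB
total_rwkv = per_layer * num_layers                         # ~1.5 MB for 12 layers

seq_len = 10_000
kv_cache = 2 * seq_len * d_model * bytes_fp32 * num_layers  # ~490 MB, well above 100 MB
print(per_layer // 1024, total_rwkv, kv_cache)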
# Initialize deeper layers with smaller weights
for layer_id, block in enumerate(model.blocks):
scale = 1.0 / (layer_id + 1)
# Scale down projections
block.time_mixing.r_proj.weight.data *= scale
block.time_mixing.k_proj.weight.data *= scale
    block.time_mixing.v_proj.weight.data *= scale

# Initialize decay for appropriate time scales
num_heads = 8
for h in range(num_heads):
# Each head gets different decay rate
decay_rate = 0.2 + 0.7 * h / (num_heads - 1)
block.time_mixing.w_base.data[h*64:(h+1)*64].fill_(
-math.log(decay_rate)
    )

# Use warmup + cosine decay
warmup_steps = 2000
max_lr = 6e-4
def lr_schedule(step):
if step < warmup_steps:
return max_lr * step / warmup_steps
else:
progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Essential for stability
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Properly maintain state during generation
def generate(model, prompt, max_length=100):
state = None # Initialize once
tokens = [prompt]
for _ in range(max_length):
# Pass state through
logits, state = model(tokens[-1:], state)
next_token = sample(logits)
tokens.append(next_token)
    return tokens

# Wrong: direct projection
k = self.k_proj(x)
# Correct: with token shift
x_k, _ = self.token_shift(x, last_x)
k = self.k_proj(x_k)

# Wrong: positive decay (exponential growth!)
w = F.softplus(self.w_proj(x))
# Correct: negative decay (exponential forgetting)
w = -F.softplus(self.w_proj(x))

# Wrong: no normalization
output = output * gate
# Correct: group norm before gating
output = self.group_norm(output)
output = output * gate

# Wrong: mismatched state dimensions
state = torch.zeros(batch, d_model, d_model)
# Correct: per-head states
state = torch.zeros(batch, num_heads, head_dim, head_dim)

For very large head dimensions, use sparse updates:
def sparse_wkv_update(state, k, v, decay, top_k=32):
"""Update only top-k most relevant state entries."""
# Find top-k dimensions by key magnitude
_, top_idx = k.abs().topk(top_k, dim=-1)
# Sparse key-value outer product
k_sparse = torch.zeros_like(k)
v_sparse = torch.zeros_like(v)
k_sparse.scatter_(-1, top_idx, k.gather(-1, top_idx))
v_sparse.scatter_(-1, top_idx, v.gather(-1, top_idx))
# Standard WKV update with sparse K, V
state = decay * state + torch.einsum('...d,...e->...de', k_sparse, v_sparse)
    return state

Combine RWKV-6 with occasional attention layers:
class HybridBlock(nn.Module):
def __init__(self, d_model, use_attention_every=4):
super().__init__()
self.layers = nn.ModuleList()
for i in range(12):
if i % use_attention_every == 0:
self.layers.append(TransformerBlock(d_model))
else:
                self.layers.append(RWKV6Block(d_model))

Enable continual learning by managing state:
class ContinualRWKV6:
def __init__(self, model):
self.model = model
self.long_term_state = None
def process_batch(self, batch):
# Use accumulated state
output, self.long_term_state = self.model(
batch, self.long_term_state
)
return output
def forget_old_context(self, keep_ratio=0.5):
# Decay old memories
if self.long_term_state is not None:
for state_dict in self.long_term_state:
                state_dict['wkv'] *= keep_ratio
References:

- RWKV-6 (Finch): Peng et al., "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence", 2024. https://arxiv.org/abs/2404.05892
- RWKV: Peng et al., "RWKV: Reinventing RNNs for the Transformer Era", EMNLP 2023. https://arxiv.org/abs/2305.13048
- Linear Attention: Katharopoulos et al., "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention", ICML 2020. https://arxiv.org/abs/2006.16236
- AFT: Zhai et al., "An Attention Free Transformer", 2021. https://arxiv.org/abs/2105.14103
- RetNet: Sun et al., "Retentive Network: A Successor to Transformer for Large Language Models", 2023. https://arxiv.org/abs/2307.08621
- Mamba: Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", 2023. https://arxiv.org/abs/2312.00752
RWKV-6 (Finch) represents a significant advancement in efficient sequence modeling:
Key Strengths:
- Matrix-valued states provide high expressivity
- Data-dependent decay enables adaptive memory
- O(1) inference complexity enables true long-context processing
- Competitive with transformers on quality metrics
Ideal Use Cases:
- Long-context language modeling (>10K tokens)
- Edge deployment with limited memory
- Real-time generation applications
- Continual learning scenarios
Trade-offs:
- Slightly slower than vector-state RNNs (RWKV-4/5)
- Training requires careful initialization
- Less parallelizable than pure attention
RWKV-6 bridges the gap between RNN efficiency and transformer expressivity, making it an excellent choice for production language models that require both quality and efficiency.