Linear RNNs represent a foundational class of sequence models that achieve linear time complexity O(n) while avoiding the quadratic complexity O(n²) of standard attention mechanisms. They serve as the base architecture for many modern efficient sequence models including RWKV, Mamba, DeltaNet, and RetNet.
| Aspect | Standard Attention | Linear RNN |
|---|---|---|
| Training complexity | O(n²) | O(n) |
| Inference complexity | O(n) per token (with KV cache) | O(1) per token |
| Memory | O(n²) | O(d) state |
| Parallelization | Full sequence | Chunk-wise possible |
| Long sequences | Prohibitive | Efficient |
| Hardware efficiency | Tensor cores | Memory-bound |
Linear RNNs enable efficient processing of long sequences (100k+ tokens) that would be infeasible for standard transformers, while maintaining constant-time inference complexity.
The fundamental insight is that certain recurrent architectures can be computed in two equivalent modes:
- Recurrent mode (inference): O(1) per step
  h[t] = f(h[t-1], x[t])
  y[t] = g(h[t])
- Parallel mode (training): O(n log n) via FFT convolution, or O(n) work via parallel scan
  y = conv(kernel, x)  or  y = parallel_scan(x)
This duality enables efficient training (parallel) while maintaining RNN-like inference efficiency (constant time per token).
Linear RNNs are based on linear time-invariant recurrences:
h[t] = A h[t-1] + B x[t]
y[t] = C h[t] + D x[t]
where:
- h[t] ∈ ℝ^N: hidden state
- x[t] ∈ ℝ^d: input
- y[t] ∈ ℝ^d: output
- A ∈ ℝ^(N×N): state transition matrix
- B ∈ ℝ^(N×d): input matrix
- C ∈ ℝ^(d×N): output matrix
- D ∈ ℝ^(d×d): feedthrough matrix (often 0)
The recurrent formulation can be "unrolled" into a convolution:
y[t] = Σ_{k=0}^{t} C A^k B x[t-k] + D x[t]
Let K[k] = C A^k B be the convolution kernel. Then:
y = K * x (convolution)
This convolution can be computed efficiently:
- Time domain: O(n²) via naive direct convolution
- Frequency domain: O(n log n) via FFT
- Parallel scan: O(n) work and O(log n) depth, often with better constants in practice
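To make this duality concrete, the following sketch (PyTorch; all dimensions are illustrative, not tied to any particular model) computes the same output once by unrolling the recurrence and once by convolving with the kernel K[k] = C A^k B:

```python
import torch

torch.manual_seed(0)
N, d, n = 8, 4, 32                       # state dim, model dim, sequence length

A = torch.randn(N, N) * 0.1              # small entries keep the recurrence stable
B = torch.randn(N, d)
C = torch.randn(d, N)
x = torch.randn(n, d)

# 1) Recurrent mode: h[t] = A h[t-1] + B x[t], y[t] = C h[t]
h = torch.zeros(N)
y_rec = []
for t in range(n):
    h = A @ h + B @ x[t]
    y_rec.append(C @ h)
y_rec = torch.stack(y_rec)

# 2) Convolutional mode: y[t] = sum_k (C A^k B) x[t-k]
K = torch.stack([C @ torch.matrix_power(A, k) @ B for k in range(n)])  # (n, d, d)
y_conv = torch.stack([
    sum(K[j] @ x[t - j] for j in range(t + 1)) for t in range(n)
])

print(torch.allclose(y_rec, y_conv, atol=1e-4))  # True
```

The two modes agree term by term, which is exactly what lets training use the parallel form while inference uses the O(1)-per-step recurrence.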
Pure linear RNNs are limited in expressivity. Modern linear RNNs add:
- Gating mechanisms: control information flow
  g[t] = σ(W_g x[t])
  y[t] = g[t] ⊙ (C h[t])
- Input-dependent parameters: make A, B, C depend on the input
  A[t] = f_A(x[t])
  B[t] = f_B(x[t])
- Short convolutions: add local context before the recurrence
  x'[t] = Conv1d(x)[t]
  h[t] = A h[t-1] + B x'[t]
Efficient state management is crucial for linear RNNs:
# Training: parallel over sequence
h_all = parallel_compute(x_all) # (batch, seq_len, state_dim)
# Inference: sequential, constant memory
h_t = recurrent_step(h_prev, x_t) # (batch, state_dim)

The state typically has constant size O(d), or O(d²) for matrix-valued states, regardless of sequence length.
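This equivalence is easy to check for a concrete recurrence. In the sketch below (a plain decay recurrence stands in for the model; it is not the library API), feeding the sequence token by token with a carried state gives exactly the same result as processing it in one call:

```python
import torch

def decay_recurrence(x, state, decay=0.9):
    """h[t] = decay * h[t-1] + x[t]; returns all states and the final state."""
    outputs = []
    for t in range(x.shape[1]):
        state = decay * state + x[:, t]
        outputs.append(state)
    return torch.stack(outputs, dim=1), state

torch.manual_seed(0)
x = torch.randn(2, 64, 16)            # (batch, seq_len, state_dim)
init = torch.zeros(2, 16)

# "Training" mode: the whole sequence in one call
y_full, _ = decay_recurrence(x, init)

# "Inference" mode: one token at a time, carrying the state
outputs, state = [], init
for t in range(x.shape[1]):
    y_t, state = decay_recurrence(x[:, t:t + 1], state)
    outputs.append(y_t)
y_step = torch.cat(outputs, dim=1)

print(torch.allclose(y_full, y_step))  # True
```

The only memory carried between steps is the O(d) state tensor, independent of how many tokens have been processed.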
The full Linear RNN block consists of:
# Input projection
x_branch, z = split(Linear(x)) # (batch, seq, hidden_dim) each
# Short convolution (optional)
x_conv = Conv1d(x_branch) # Local context
# Activation
x_active = silu(x_conv)
# Recurrent computation (implemented by subclasses)
y = recurrence(x_active, state)
# Normalization
y = LayerNorm(y)
# Gating
y_gated = y * silu(z)
# Output projection
output = Linear(y_gated)
Many linear RNNs use short depthwise convolutions to capture local patterns:
Conv1d:
- Kernel size: 3-4 (small receptive field)
- Groups: d_inner (depthwise, one filter per channel)
- Padding: k-1 (causal, no future leakage)
- Output: y[:, :seq_len] (trim to sequence length)
This provides:
- Local inductive bias: Nearby tokens are related
- Position information: Implicit relative positions
- Efficiency: Depthwise conv is cheap (one filter per channel)
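These properties can be verified directly; a short sketch (sizes illustrative) checks that causal padding plus trimming prevents any future leakage:

```python
import torch
import torch.nn as nn

d, k, n = 16, 4, 10
# Depthwise (groups=d) causal (padding=k-1) convolution, as described above
conv = nn.Conv1d(d, d, kernel_size=k, padding=k - 1, groups=d)

x = torch.randn(1, d, n)
y = conv(x)[:, :, :n]        # trim the k-1 extra positions on the right

# Perturb only future tokens (positions 5..9)
x2 = x.clone()
x2[:, :, 5:] += 1.0
y2 = conv(x2)[:, :, :n]

# Outputs at positions 0..4 are untouched: the convolution is causal
print(torch.allclose(y[:, :, :5], y2[:, :, :5]))  # True
```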
States should be initialized properly:
# Vector state (simple RNN)
state = zeros(batch, hidden_dim)
# Matrix state (key-value style)
state = zeros(batch, num_heads, head_dim, head_dim)
# Multi-component state (complex architectures)
state = {
'kv_state': zeros(batch, num_heads, head_dim, head_dim),
'denominator': zeros(batch, num_heads, head_dim),
'last_token': None
}

Common recurrence patterns in linear RNNs:
Exponential decay (RWKV, RetNet):
h[t] = w[t] ⊙ h[t-1] + k[t] ⊗ v[t]
where w[t] ∈ (0, 1) is the decay factor.
Additive update (S4, Mamba):
h[t] = A h[t-1] + B[t] x[t]
where A is the state transition matrix and B[t] the (possibly input-dependent) input projection.
Delta rule (DeltaNet, RWKV-7):
error = v[t] - h[t-1] @ k[t]
h[t] = decay * h[t-1] + beta[t] * k[t] ⊗ error
where error correction drives updates.
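A minimal reference implementation of the delta-rule update may help here; the shapes and the per-step learning rate `beta` follow the equations above, but the function itself is an illustrative sketch, not the library's API:

```python
import torch

def delta_rule_recurrence(k, v, beta, state, decay=1.0):
    """Delta-rule update: the state is corrected toward storing k -> v.

    Illustrative shapes: k, v: (batch, seq, heads, head_dim);
    beta: (batch, seq, heads); state: (batch, heads, head_dim, head_dim).
    """
    seq_len = k.shape[1]
    outputs = []
    for t in range(seq_len):
        # Current prediction for key k[t], and the error against the target v[t]
        pred = torch.einsum('bhde,bhd->bhe', state, k[:, t])
        error = v[:, t] - pred
        # Rank-1 correction scaled by the learning rate beta[t]
        update = torch.einsum('bhd,bhe->bhde', k[:, t], error)
        state = decay * state + beta[:, t, :, None, None] * update
        outputs.append(pred)
    return torch.stack(outputs, dim=1), state
```

With beta = 1, no decay, and a unit-norm key, a single step stores the pair exactly: querying the updated state with k returns v, which is the associative-memory behavior the delta rule is designed for.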
Think of Linear RNNs as a hierarchy of abstractions:
- Bottom layer (computation): Efficiently compute recurrent operations via convolution or parallel scan
- Middle layer (state): Maintain compact state that summarizes past context
- Top layer (gating): Control what information flows through via learned gates
The key tradeoff compared to attention:
- Attention: Full access to all past tokens (O(n²) but flexible)
- Linear RNN: Compressed summary via state (O(n) but lossy)
Success depends on:
- State capacity: Large enough to capture relevant history
- Gating quality: Smart enough to filter irrelevant information
- Inductive bias: Architecture matches task structure
class LinearRNN(NexusModule):
"""Base Linear RNN architecture.
Provides common infrastructure for linear recurrent models.
Subclasses implement specific recurrence patterns.
Args:
dim: Model dimension
expand: Hidden expansion factor (default: 2)
bias: Use bias in projections (default: True)
use_short_conv: Use convolution for local context (default: True)
conv_size: Convolution kernel size (default: 4)
"""
def __init__(
self,
dim: int,
expand: int = 2,
bias: bool = True,
use_short_conv: bool = True,
conv_size: int = 4
):
super().__init__()
self.dim = dim
self.expand = expand
self.hidden_dim = dim * expand
# Input projection (2x for main branch + gate)
self.in_proj = nn.Linear(dim, self.hidden_dim * 2, bias=bias)
# Short convolution
if use_short_conv:
self.conv = nn.Conv1d(
in_channels=self.hidden_dim,
out_channels=self.hidden_dim,
kernel_size=conv_size,
padding=conv_size - 1, # Causal padding
groups=self.hidden_dim, # Depthwise
bias=bias
)
# Output projection
self.out_proj = nn.Linear(self.hidden_dim, dim, bias=bias)
# Normalization
self.norm = nn.LayerNorm(self.hidden_dim)
def forward(
self,
x: torch.Tensor,
state: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x: Input (batch, seq_len, dim)
state: Recurrent state (optional)
Returns:
output: Output (batch, seq_len, dim)
state: Updated state
"""
batch_size, seq_len, _ = x.shape
# Project input
xz = self.in_proj(x)
x_branch, z = xz.chunk(2, dim=-1)
# Apply short convolution
if hasattr(self, 'conv'):
x_branch = x_branch.transpose(1, 2) # (B, H, L)
x_branch = self.conv(x_branch)[:, :, :seq_len] # Causal
x_branch = x_branch.transpose(1, 2) # (B, L, H)
# Activation
x_branch = F.silu(x_branch)
# Recurrent computation (subclass-specific)
y, state = self.recurrent_forward(x_branch, state)
# Normalize and gate
y = self.norm(y)
y = y * F.silu(z)
# Project output
output = self.out_proj(y)
return output, state
def recurrent_forward(
self,
x: torch.Tensor,
state: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Recurrent computation. Override in subclasses."""
return x, state

class ShortConvolution(NexusModule):
"""Short depthwise convolution for local context.
Captures local patterns before recurrent processing.
Supports both parallel and incremental modes.
Args:
dim: Channel dimension
kernel_size: Convolution kernel size (default: 4)
bias: Use bias (default: True)
causal: Causal convolution (default: True)
"""
def __init__(
self,
dim: int,
kernel_size: int = 4,
bias: bool = True,
causal: bool = True
):
super().__init__()
self.dim = dim
self.kernel_size = kernel_size
self.causal = causal
padding = kernel_size - 1 if causal else kernel_size // 2
self.conv = nn.Conv1d(
in_channels=dim,
out_channels=dim,
kernel_size=kernel_size,
padding=padding,
groups=dim, # Depthwise
bias=bias
)
def forward(
self,
x: torch.Tensor,
state: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x: Input (batch, seq_len, dim)
state: Conv cache for incremental decoding
Returns:
output: Convolved output
state: Updated conv cache
"""
batch_size, seq_len, dim = x.shape
# Incremental decoding (single token)
if state is not None and seq_len == 1:
# Concatenate with cached history
x_cache = torch.cat([state, x], dim=1)
x_cache = x_cache.transpose(1, 2)
y = self.conv(x_cache)[:, :, -1:]
y = y.transpose(1, 2)
# Update cache (keep last kernel_size-1 positions)
new_state = x_cache.transpose(1, 2)[:, -(self.kernel_size-1):, :]
return y, new_state
# Parallel processing (full sequence)
x = x.transpose(1, 2) # (B, dim, seq)
y = self.conv(x)
if self.causal:
y = y[:, :, :seq_len] # Trim excess
y = y.transpose(1, 2) # (B, seq, dim)
# Cache for future incremental decoding
if seq_len >= self.kernel_size - 1:
new_state = x.transpose(1, 2)[:, -(self.kernel_size-1):, :]
else:
new_state = x.transpose(1, 2)
return y, new_state

Here are implementations of common recurrence patterns used in linear RNNs:
# 1. Exponential decay (RWKV-style)
def exponential_decay_recurrence(k, v, w, state):
"""
Args:
k: Key (batch, seq, heads, head_dim)
v: Value (batch, seq, heads, head_dim)
w: Decay (batch, seq, heads, head_dim)
state: (batch, heads, head_dim, head_dim)
"""
seq_len = k.shape[1]
outputs = []
for t in range(seq_len):
# Apply decay (w is assumed to hold log-decays, i.e., w ≤ 0, so exp(w) ∈ (0, 1])
decay = torch.exp(w[:, t]) # (batch, heads, head_dim)
state = state * decay.unsqueeze(-1)
# Add new key-value
kv = torch.einsum('bhd,bhe->bhde', k[:, t], v[:, t])
state = state + kv
outputs.append(state)
return torch.stack(outputs, dim=1), state
# 2. Linear SSM (S4-style)
def linear_ssm_recurrence(x, A, B, C, state):
"""
Args:
x: Input (batch, seq, dim)
A: Transition (state_dim, state_dim)
B: Input projection (state_dim, dim)
C: Output projection (dim, state_dim)
state: (batch, state_dim)
"""
seq_len = x.shape[1]
outputs = []
for t in range(seq_len):
# State update: h = A @ h + B @ x
state = torch.matmul(state, A.T) + torch.matmul(x[:, t], B.T)
# Output: y = C @ h
y = torch.matmul(state, C.T)
outputs.append(y)
return torch.stack(outputs, dim=1), state
# 3. Gated update (retention-style)
def gated_recurrence(q, k, v, decay, state):
"""
Args:
q, k, v: Query, Key, Value (batch, seq, heads, head_dim)
decay: Decay factor per head (heads,)
state: (batch, heads, head_dim, head_dim)
"""
seq_len = q.shape[1]
outputs = []
for t in range(seq_len):
# Decay state
state = decay.view(1, -1, 1, 1) * state
# Add outer product
kv = torch.einsum('bhd,bhe->bhde', k[:, t], v[:, t])
state = state + kv
# Query state
o = torch.einsum('bhd,bhde->bhe', q[:, t], state)
outputs.append(o)
return torch.stack(outputs, dim=1), state

import torch
from nexus.components.ssm import LinearRNN, ShortConvolution
# Create base linear RNN
model = LinearRNN(
dim=512,
expand=2,
use_short_conv=True,
conv_size=4
)
# Forward pass (training)
x = torch.randn(2, 100, 512) # (batch, seq_len, dim)
output, state = model(x)
print(f"Output shape: {output.shape}") # (2, 100, 512)
print(f"State shape: {state.shape if state is not None else None}") # Depends on the subclass recurrence; None for the bare base class
# Incremental decoding (inference)
state = model.init_state(batch_size=1, device='cuda')
for t in range(100):
x_t = torch.randn(1, 1, 512).cuda() # Single token
output_t, state = model(x_t, state)
print(f"Step {t}: {output_t.shape}") # (1, 1, 512)

class MyCustomRNN(LinearRNN):
"""Custom linear RNN with specific recurrence pattern."""
def __init__(self, dim, expand=2):
super().__init__(dim, expand)
# Custom parameters for recurrence
self.W_state = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
self.decay = nn.Parameter(torch.randn(self.hidden_dim))
def recurrent_forward(self, x, state):
"""Implement custom recurrence."""
batch_size, seq_len, hidden_dim = x.shape
if state is None:
state = torch.zeros(batch_size, hidden_dim, device=x.device)
outputs = []
for t in range(seq_len):
# Custom recurrence: h = decay * h + W @ x
state = torch.sigmoid(self.decay) * state + self.W_state(x[:, t])
outputs.append(state)
output = torch.stack(outputs, dim=1)
return output, state
# Use custom RNN
custom_rnn = MyCustomRNN(dim=512)
x = torch.randn(2, 100, 512)
output, state = custom_rnn(x)

from nexus.components.ssm import ShortConvolution
# Create short convolution
conv = ShortConvolution(
dim=512,
kernel_size=4,
causal=True
)
# Parallel mode (training)
x = torch.randn(2, 100, 512)
y, cache = conv(x)
print(f"Output shape: {y.shape}") # (2, 100, 512)
# Incremental mode (inference)
cache = None
for t in range(10):
x_t = torch.randn(2, 1, 512)
y_t, cache = conv(x_t, cache)
print(f"Step {t}: output {y_t.shape}, cache {cache.shape}")
# output (2, 1, 512), cache (2, 3, 512) - last k-1 tokens

class ExponentialDecayRNN(LinearRNN):
"""Linear RNN with exponential decay (like RWKV)."""
def __init__(self, dim, expand=2, num_heads=8):
super().__init__(dim, expand)
self.num_heads = num_heads
self.head_dim = self.hidden_dim // num_heads
# Per-head decay factors
self.decay = nn.Parameter(torch.randn(num_heads, self.head_dim))
# Projections
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
def recurrent_forward(self, x, state):
batch, seq_len, _ = x.shape
# Initialize state (matrix-valued)
if state is None:
state = torch.zeros(
batch, self.num_heads, self.head_dim, self.head_dim,
device=x.device, dtype=x.dtype
)
# Project to K, V
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
# Recurrent processing
outputs = []
decay_factor = torch.sigmoid(self.decay) # Bound to (0, 1)
for t in range(seq_len):
# Decay state
state = state * decay_factor.unsqueeze(0).unsqueeze(-1)
# Add new key-value
kv = torch.einsum('bhd,bhe->bhde', k[:, t], v[:, t])
state = state + kv
# Read from state (v doubles as the read-out query here for simplicity;
# RWKV-style models use a separate receptance projection)
output_t = torch.einsum('bhde,bhe->bhd', state, v[:, t])
outputs.append(output_t)
# Stack and reshape
output = torch.stack(outputs, dim=1)
output = output.reshape(batch, seq_len, self.hidden_dim)
return output, state
# Usage
rnn = ExponentialDecayRNN(dim=512, num_heads=8)
x = torch.randn(2, 100, 512)
output, state = rnn(x)
print(f"Output: {output.shape}, State: {state.shape}")
# Output: (2, 100, 512), State: (2, 8, 64, 64)

Comparison of state memory requirements:
| Architecture | State Size | Example (d=512, h=8) |
|---|---|---|
| Transformer (KV cache) | O(n×d) | n × 512 |
| Linear RNN (vector) | O(d) | 512 |
| Linear RNN (matrix) | O(d²/h) | 8 × 64 × 64 = 32K |
| RWKV-6/7 | O(d²/h) | 8 × 64 × 64 = 32K |
| DeltaNet | O(d²/h) | 8 × 64 × 64 = 32K |
For long sequences (n > 64), linear RNNs use dramatically less memory.
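The crossover in this table can be reproduced with back-of-envelope arithmetic; the sketch below assumes fp16 storage, a single layer, and counts both keys and values for the cache (assumptions, not measurements):

```python
# Back-of-envelope check (assumptions: fp16 storage, per layer, K and V both cached)
bytes_per = 2
d, h, n = 512, 8, 100_000

kv_cache = n * 2 * d * bytes_per              # grows with sequence length
matrix_state = h * (d // h) ** 2 * bytes_per  # fixed size, independent of n

print(f"KV cache at n=100k: {kv_cache / 1e6:.0f} MB")     # 205 MB
print(f"Matrix state:       {matrix_state / 1e3:.0f} KB") # 66 KB
```

At 100k tokens the fixed-size matrix state is over three orders of magnitude smaller than the growing KV cache.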
Benchmarked on A100 GPU, batch_size=1, d_model=512:
| Model | Tokens/sec (seq=1K) | Tokens/sec (seq=10K) | Tokens/sec (seq=100K) |
|---|---|---|---|
| Transformer | 2500 | 450 | OOM |
| Linear RNN (basic) | 3800 | 3700 | 3600 |
| Mamba | 4200 | 4100 | 4000 |
| RWKV-6 | 4500 | 4400 | 4300 |
Linear RNNs maintain constant throughput regardless of context length.
On 8×A100, WikiText-103, d_model=1024:
| Model | Tokens/sec | Memory (GB) | Wall Time (1 epoch) |
|---|---|---|---|
| Transformer | 45K | 72 | 8.2 hours |
| Linear RNN | 38K | 48 | 9.7 hours |
| Mamba | 52K | 42 | 7.1 hours |
| Mamba-2 | 88K | 38 | 4.2 hours |
Linear RNNs are competitive, with Mamba-2 being fastest due to hardware-aware implementation.
Performance on LRA tasks (accuracy %):
| Model | ListOps | Text | Retrieval | Image | Path-X | Avg |
|---|---|---|---|---|---|---|
| Transformer | 36.4 | 64.3 | 57.5 | 42.4 | 71.2 | 54.4 |
| Linear RNN (basic) | 38.1 | 62.8 | 79.2 | 41.3 | 68.5 | 58.0 |
| S4 | 58.3 | 76.3 | 87.8 | 88.1 | 86.4 | 79.4 |
| Mamba | 62.7 | 82.1 | 89.3 | 91.2 | 92.8 | 83.6 |
Modern linear RNN variants (S4, Mamba) significantly outperform standard transformers on long-range tasks.
Choose state size based on task:
# Vector states: simple tasks, short-term dependencies
state_dim = d_model # O(d)
# Matrix states: complex tasks, long-range dependencies
state_dim = (num_heads, head_dim, head_dim) # O(d²/h)

Rule of thumb: start with matrix states for general-purpose models.
Short convolutions should be small:
# Too small: insufficient local context
conv_size = 2 # ❌ Not recommended
# Good: captures 3-4 token window
conv_size = 4 # ✅ Recommended (3 also works)
# Too large: expensive, defeats purpose
conv_size = 16 # ❌ Use attention instead

Typical value: conv_size = 4 provides good local context.
Balance expressivity and efficiency:
# Small: efficient but limited capacity
expand = 1 # ❌ Too restrictive
# Good: standard setting
expand = 2 # ✅ Recommended
# Large: high capacity but expensive
expand = 4 # ✅ For large models only

Larger models can use larger expansion factors.
Always use gating for non-linearity:
# Bad: no gating
output = recurrence(x)
# Good: multiplicative gating
x_main, z = split(project(x))
y = recurrence(x_main)
output = y * sigmoid(z) # ✅ Gated output

Gating is essential for model expressivity.
Normalize before gating:
# Correct order
y = recurrence(x)
y = LayerNorm(y) # Normalize first
y = y * silu(gate) # Then gate
output = project(y)
# Wrong order (unstable)
y = recurrence(x)
y = y * silu(gate) # ❌ Gate first
y = LayerNorm(y) # Then normalize

Normalization before gating prevents instability.
Initialize parameters carefully:
# Decay parameters: start near 1 (long memory)
decay = nn.Parameter(torch.ones(dim) * 0.9)
# Projection matrices: Xavier/Kaiming
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.kaiming_uniform_(self.out_proj.weight)
# Bias: zeros
nn.init.zeros_(self.in_proj.bias)

Good initialization is critical for training stability.
Properly handle state during generation:
# Initialize state once
state = model.init_state(batch_size=1, device='cuda')
# Generate tokens
generated = []
input_ids = prompt_ids # Start from the prompt (assumed given)
for _ in range(max_tokens):
# Single token forward
logits, state = model(input_ids, state)
# Sample next token
next_token = sample(logits[:, -1])
generated.append(next_token)
# Use next token as input
input_ids = next_token.unsqueeze(0)
# ✅ State is properly maintained across steps

Never recreate state during generation.
# ❌ Wrong: state shape doesn't match batch size
state = torch.zeros(1, dim) # batch=1
x = torch.randn(8, seq_len, dim) # batch=8
output, state = model(x, state) # Error!
# ✅ Correct: match batch sizes
state = torch.zeros(8, dim)
output, state = model(x, state)

# ❌ Wrong: non-causal padding sees future
conv = nn.Conv1d(dim, dim, kernel_size=4, padding=2)
# ✅ Correct: causal padding
conv = nn.Conv1d(dim, dim, kernel_size=4, padding=3)
y = conv(x)[:, :, :seq_len] # Trim excess

# ❌ Wrong: full convolution is expensive
conv = nn.Conv1d(512, 512, kernel_size=4) # 512×512×4 params
# ✅ Correct: depthwise convolution
conv = nn.Conv1d(512, 512, kernel_size=4, groups=512) # 512×4 params

# ❌ Wrong: creating new state instead of updating
def forward(x, state):
state = torch.zeros_like(state) # Erases memory!
...
# ✅ Correct: update existing state
def forward(x, state):
state = decay * state + update # Maintains memory
...

# ❌ Wrong: using recurrence during training
if self.training:
for t in range(seq_len): # Slow!
state = recurrence(state, x[:, t])
# ✅ Correct: use parallel method during training
if self.training:
output = parallel_scan(x) # Fast!
else:
output = recurrence(x, state) # Necessary for inference

Efficient parallel computation of recurrences:
import math
import torch

def parallel_scan(inputs, decay):
    """Compute h[t] = decay * h[t-1] + inputs[t] in O(log n) parallel steps.

    Args:
        inputs: (batch, seq_len, dim)
        decay: Decay factor (scalar or per-dim)
    Returns:
        outputs: Inclusive scan (batch, seq_len, dim)
    """
    outputs = inputs.clone()  # Avoid mutating the caller's tensor
    n = outputs.shape[1]
    log_n = math.ceil(math.log2(n))

    # Upsweep: build partial sums over power-of-two blocks. Merging a block of
    # length stride // 2 into its right neighbor decays it by decay ** (stride // 2).
    for d in range(log_n):
        stride = 2 ** (d + 1)
        factor = decay ** (stride // 2)
        for i in range(stride - 1, n, stride):
            outputs[:, i] = factor * outputs[:, i - stride // 2] + outputs[:, i]

    # Downsweep: propagate completed prefixes to the remaining positions,
    # again decaying each prefix by decay ** (stride // 2) before adding
    for d in range(log_n - 1, -1, -1):
        stride = 2 ** (d + 1)
        factor = decay ** (stride // 2)
        for i in range(3 * stride // 2 - 1, n, stride):
            outputs[:, i] = factor * outputs[:, i - stride // 2] + outputs[:, i]

    return outputs

Balance parallelism and memory:
def chunk_wise_recurrence(x, state, chunk_size=64):
"""Process sequence in chunks.
Args:
x: Input (batch, seq_len, dim)
state: Initial state
chunk_size: Chunk size for processing
"""
outputs = []
for i in range(0, x.shape[1], chunk_size):
chunk = x[:, i:i+chunk_size]
# Process chunk in parallel
chunk_output, state = process_chunk(chunk, state)
outputs.append(chunk_output)
return torch.cat(outputs, dim=1), state

Combine forward and backward passes:
class BidirectionalLinearRNN(LinearRNN):
"""Bi-directional linear RNN."""
def __init__(self, dim, expand=2):
super().__init__(dim, expand)
self.backward_rnn = LinearRNN(dim, expand)
def forward(self, x, state_fwd=None, state_bwd=None):
# Forward pass
y_fwd, state_fwd = super().forward(x, state_fwd)
# Backward pass
x_rev = torch.flip(x, [1])
y_bwd, state_bwd = self.backward_rnn(x_rev, state_bwd)
y_bwd = torch.flip(y_bwd, [1])
# Combine
output = y_fwd + y_bwd
return output, (state_fwd, state_bwd)

Use multiple decay rates:
class MultiScaleRNN(LinearRNN):
"""RNN with multiple time scales."""
def __init__(self, dim, expand=2, num_scales=4):
super().__init__(dim, expand)
self.num_scales = num_scales
# Different decay rates per scale
self.decays = nn.Parameter(
torch.linspace(0.1, 0.9, num_scales)
)
def recurrent_forward(self, x, states):
if states is None:
states = [None] * self.num_scales
outputs = []
new_states = []
for i, decay in enumerate(self.decays):
# Process with this time scale
y, state = self.scale_recurrence(x, states[i], decay)
outputs.append(y)
new_states.append(state)
# Combine scales
output = sum(outputs) / self.num_scales
return output, new_states

- Linear Recurrent Units (LRU)
  - Orvieto et al., "Resurrecting Recurrent Neural Networks for Long Sequences", ICML 2023
  - https://arxiv.org/abs/2303.06349
  - Analyzes why and how linear RNNs can be effective
- S4: Structured State Spaces
  - Gu et al., "Efficiently Modeling Long Sequences with Structured State Spaces", ICLR 2022
  - https://arxiv.org/abs/2111.00396
  - Foundational work on efficient SSM computation
- Parallel Scan for RNNs
  - Martin & Cundy, "Parallelizing Linear Recurrent Neural Networks over Sequence Length", ICLR 2018
  - https://arxiv.org/abs/1709.04057
  - Describes the parallel scan algorithm
- Mamba: Selective State Spaces
  - Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", 2023
  - https://arxiv.org/abs/2312.00752
  - Input-dependent parameters for better expressivity
- RetNet: Retentive Networks
  - Sun et al., "Retentive Network: A Successor to Transformer for Large Language Models", 2023
  - https://arxiv.org/abs/2307.08621
  - Multi-scale retention mechanism
- RWKV: Receptance Weighted Key Value
  - Peng et al., "RWKV: Reinventing RNNs for the Transformer Era", EMNLP 2023
  - https://arxiv.org/abs/2305.13048
  - Demonstrates competitive LLM performance
- Linear Attention and SSMs
  - Dao & Gu, "Transformers are SSMs: Generalized Models and Efficient Algorithms through Structured State Space Duality", 2024
  - https://arxiv.org/abs/2405.21060
  - Unifies linear attention and SSMs theoretically
- Expressivity of Linear RNNs
  - Merrill et al., "Provable Limitations of Acquiring Meaning from Ungrounded Form: What Will Future Language Models Understand?", TACL 2021
  - https://arxiv.org/abs/2104.10809
  - Theoretical limits of linear recurrence
- Annotated S4
  - Rush & Karamcheti, "The Annotated S4", 2022
  - https://srush.github.io/annotated-s4/
  - Line-by-line implementation walkthrough
- Mamba Implementation Notes
  - Dao, "Mamba: The Hard Parts", 2024
  - https://github.com/state-spaces/mamba
  - Hardware-aware implementation details
Linear RNNs provide the foundational infrastructure for modern efficient sequence models. By understanding the base architecture, recurrence patterns, and implementation techniques covered in this document, you can:
- Implement custom linear RNN variants for specific tasks
- Debug and optimize existing implementations
- Make informed choices about which SSM variant to use
- Build new hybrid architectures combining multiple techniques
The key insight is the duality between recurrent and parallel computation: the same model can be efficiently trained (parallel) and efficiently deployed (recurrent). This makes linear RNNs practical for real-world applications with long sequences.
For specific use cases, consider the specialized variants:
- S4/S4D: Long-range dependencies with theoretical guarantees
- Mamba/Mamba-2: General-purpose language modeling
- RetNet: Multi-scale temporal patterns
- RWKV: Efficient LLMs with true O(1) inference
- DeltaNet: Associative memory and retrieval
All these variants build on the linear RNN foundation described in this document.