S4 (Structured State Space sequence model) revolutionized sequence modeling by showing that continuous-time state space models, when properly initialized and parameterized, can effectively model long-range dependencies with linear complexity. Before S4, RNNs struggled with long sequences due to vanishing gradients, while Transformers required quadratic complexity.
| Aspect | Transformers | S4 |
|---|---|---|
| Training complexity | O(n² d) | O(n log n · d) |
| Inference complexity | O(n² d) | O(n d) or O(d) per step |
| Long-range modeling | Requires positional encodings | Built-in through HiPPO |
| Memory usage | High (attention matrix) | Low (state vector) |
| Parallelization | Excellent | Good |
S4 bridges the gap between RNNs (efficient inference) and Transformers (parallelizable training).
State space models describe dynamical systems through differential equations:
Continuous time:
dx/dt = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
Discrete time (after discretization with step Δ):
x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k + D u_k
where:
- x(t) ∈ ℝᴺ is the hidden state (N-dimensional)
- u(t) ∈ ℝ is the input (scalar or vector)
- y(t) ∈ ℝ is the output
- A ∈ ℝᴺˣᴺ is the state transition matrix
- B ∈ ℝᴺ is the input matrix
- C ∈ ℝᴺ is the output matrix
- D ∈ ℝ is the feedthrough (skip connection)
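The discrete recurrence above can be run directly. Here is a minimal numpy sketch with toy (not HiPPO-initialized) matrices; feeding an impulse recovers the kernel terms one step at a time:

```python
import numpy as np

# Toy discrete SSM (illustrative values, not HiPPO-initialized):
# x_k = A_bar x_{k-1} + B_bar u_k,  y_k = C x_k + D u_k
N = 4
rng = np.random.default_rng(0)
A_bar = 0.9 * np.eye(N)              # stable toy transition matrix
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
D = 0.5

def ssm_step(x, u):
    x = A_bar @ x + B_bar * u
    y = float(C @ x) + D * u
    return x, y

x = np.zeros((N, 1))
ys = []
for u in [1.0, 0.0, 0.0]:            # impulse input
    x, y = ssm_step(x, u)
    ys.append(y)
# Impulse response recovers the kernel: y_0 = C B_bar + D, y_1 = C A_bar B_bar, ...
```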
The key innovation of S4 is using HiPPO (High-order Polynomial Projection Operators) to initialize the state matrix A. HiPPO matrices are derived from the optimal solution to online function approximation:
Given a continuous signal f(t), compress its history into N coefficients that minimize reconstruction error. The HiPPO-LegS (Legendre, scaled) measure gives:
A_{nk} = -(2n+1)^{1/2} (2k+1)^{1/2} if n > k
A_{nn} = -(n+1) if n = k
A_{nk} = 0 if n < k
This initialization enables S4 to remember information over very long sequences (>10k tokens).
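The closed form above translates directly into code. A small numpy sketch that constructs HiPPO-LegS and spot-checks a few entries against the formula:

```python
import numpy as np

# Construct the HiPPO-LegS matrix from the closed form above.
def hippo_legs(N):
    n = np.arange(N)
    P = np.sqrt(2 * n + 1)
    A = np.tril(P[:, None] * P[None, :], k=-1)  # strictly lower: sqrt(2n+1)*sqrt(2k+1)
    A += np.diag(n + 1)                          # diagonal: n + 1
    return -A                                    # negate: A_nk <= 0

A = hippo_legs(4)
# A_nn = -(n+1) on the diagonal, -sqrt(2n+1)sqrt(2k+1) below it, 0 above it
```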
The continuous-time SSM is:
ẋ(t) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
To make the system trainable with discrete sequences, we discretize using zero-order hold (ZOH):
x_k = exp(ΔA) x_{k-1} + (∫₀^Δ exp(sA) ds) B u_k
This simplifies to:
Ā = exp(ΔA)
B̄ = (Ā - I) A⁻¹ B
x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k + D u_k
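The ZOH formulas can be sanity-checked numerically. The sketch below uses a diagonal A (so exp(ΔA) is elementwise) and compares the closed-form B̄ against a trapezoid quadrature of the integral form; all values are illustrative:

```python
import numpy as np

# ZOH discretization for a diagonal A, checked against the integral
# B_bar = (integral_0^Delta exp(sA) ds) B via a trapezoid sum.
Delta = 0.1
lam = np.array([-1.0, -2.0])             # diagonal entries of A
B = np.array([1.0, 1.0])

A_bar = np.exp(Delta * lam)              # A_bar = exp(Delta * A)
B_bar = (A_bar - 1.0) / lam * B          # B_bar = (A_bar - I) A^{-1} B, elementwise

s = np.linspace(0.0, Delta, 10001)
vals = np.exp(np.outer(s, lam))          # exp(sA) sampled along [0, Delta]
integral = ((vals[:-1] + vals[1:]) / 2 * np.diff(s)[:, None]).sum(axis=0) * B
# The closed form and the quadrature agree to numerical precision
```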
The discrete SSM defines a convolution kernel:
K = (C B̄, C Ā B̄, C Ā² B̄, ..., C Ā^{L-1} B̄)
The output is then:
y_k = (K * u)_k = ∑_{i=0}^{k} K_i u_{k-i}
This allows parallel computation via FFT in O(L log L) time.
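A minimal numpy sketch of the FFT route (toy kernel and input): zero-padding both signals to length 2L avoids circular wrap-around, so the result matches the direct causal sum above:

```python
import numpy as np

# FFT-based causal convolution vs. the direct sum, on toy data.
L = 8
rng = np.random.default_rng(1)
K = rng.standard_normal(L)   # stand-in for the SSM kernel K[i] = C A_bar^i B_bar
u = rng.standard_normal(L)

n = 2 * L                    # pad to 2L so circular convolution == linear convolution
y_fft = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)[:L]  # O(L log L)

# Direct causal convolution for comparison: y_k = sum_{i<=k} K[i] u[k-i]
y_direct = np.array([sum(K[i] * u[k - i] for i in range(k + 1)) for k in range(L)])
```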
Computing the kernel naively is O(N²L). S4 uses the DPLR (Diagonal Plus Low-Rank) structure:
A = Λ + PQ*
where:
- Λ is a complex diagonal matrix (N×N)
- P, Q are low-rank factors (N×r, typically r=1)
This decomposition enables efficient kernel computation via the Woodbury identity and FFT.
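As a concreteness check on why DPLR helps, the Woodbury identity can be verified numerically on a toy real-valued DPLR matrix (S4's actual Λ, P, Q are complex; all values here are illustrative). The point: inverting A = Λ + PQ* only requires inverting a diagonal matrix and an r×r matrix:

```python
import numpy as np

# Woodbury identity on a toy DPLR matrix A = Lam + P Q^T (real-valued sketch):
# (Lam + P Q^T)^{-1} = Lam^{-1} - Lam^{-1} P (I_r + Q^T Lam^{-1} P)^{-1} Q^T Lam^{-1}
N, r = 4, 1
Lam = np.diag(np.array([-1.0, -2.0, -3.0, -4.0]))  # diagonal part
P = np.ones((N, r))                                 # low-rank factors (r = 1)
Q = 0.25 * np.ones((N, r))
A = Lam + P @ Q.T

Lam_inv = np.diag(1.0 / np.diag(Lam))               # diagonal inverse: O(N)
small = np.linalg.inv(np.eye(r) + Q.T @ Lam_inv @ P)  # only an r x r inverse
A_inv = Lam_inv - Lam_inv @ P @ small @ Q.T @ Lam_inv
```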
Think of S4 as a learned filter bank that processes sequences:
- State as Memory: The N-dimensional state x compresses the sequence history
- HiPPO as Initialization: Tells the model how to compress effectively (like Legendre polynomials)
- Convolution Mode (Training): Process entire sequence in parallel via FFT
- Recurrence Mode (Inference): Update state one step at a time with O(1) complexity
The magic is that the same model can:
- Train fast (parallel FFT convolution)
- Infer fast (sequential O(1) updates)
- Remember long-range dependencies (HiPPO initialization)
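The dual-mode claim above can be checked numerically with a toy diagonal SSM (illustrative values): the direct convolution with kernel K[i] = C Āⁱ B̄ produces exactly the same outputs as the step-by-step recurrence:

```python
import numpy as np

# Same diagonal SSM run in both modes; the outputs must match.
N, L = 3, 6
rng = np.random.default_rng(3)
A_bar = np.array([0.9, 0.5, -0.3])    # diagonal A_bar
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

# Convolution mode: K[i] = C A_bar^i B_bar, then a causal convolution
K = np.array([(C * A_bar**i * B_bar).sum() for i in range(L)])
y_conv = np.array([sum(K[i] * u[k - i] for i in range(k + 1)) for k in range(L)])

# Recurrence mode: x_k = A_bar x_{k-1} + B_bar u_k, one step at a time
x = np.zeros(N)
y_rec = []
for k in range(L):
    x = A_bar * x + B_bar * u[k]
    y_rec.append((C * x).sum())
```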
```python
class S4Block:
    def __init__(self, d_model, d_state=64):
        # 1. Initialize HiPPO matrix A
        self.A = hippo_initializer(d_state)  # (N, N)
        # 2. Decompose into DPLR form
        self.Lambda, self.P, self.Q = dplr_decomposition(self.A)
        # 3. Learnable input/output projections
        self.B = Parameter(randn(d_model, d_state))
        self.C = Parameter(randn(d_model, d_state))
        # 4. Discretization step (learnable)
        self.log_dt = Parameter(randn(d_model))
        # 5. Skip connection
        self.D = Parameter(randn(d_model))

    def forward_conv(self, u):
        # u: (batch, L, d_model)
        L = u.shape[1]
        # 1. Generate SSM convolution kernel
        K = self.compute_kernel(L)  # (d_model, L)
        # 2. FFT convolution (in practice, zero-pad u and K to length 2L
        #    so the circular convolution stays causal)
        u_fft = fft(u, dim=1)  # O(L log L)
        K_fft = fft(K, dim=1)
        y_fft = u_fft * K_fft
        y = ifft(y_fft, dim=1)  # O(L log L)
        # 3. Add skip connection
        y = y + self.D * u
        return y

    def forward_recurrent(self, u, state):
        # u: (batch, d_model) - single timestep
        # state: (batch, d_model, d_state)
        # 1. Discretize
        dt = exp(self.log_dt)
        A_bar = exp(dt * self.Lambda)  # Complex diagonal
        B_bar = dt * self.B
        # 2. State update (O(N) per step)
        state = A_bar * state + B_bar * u.unsqueeze(-1)
        # 3. Output
        y = (self.C * state).sum(dim=-1) + self.D * u
        return y, state
```

The kernel K[i] = C Āⁱ B̄ is computed via the DPLR structure:
```python
def compute_kernel(self, L):
    dt = exp(self.log_dt)
    # Discretize
    A_bar = exp(dt * self.Lambda)  # (d_model, N)
    B_bar = dt * self.B
    # Vandermonde: powers of A_bar, K[i] = C * A_bar^i * B_bar
    powers = A_bar.unsqueeze(-1) ** arange(L).unsqueeze(0).unsqueeze(0)
    # powers: (d_model, N, L)
    # Contract the state dimension: K[d, l] = sum_n (C * B_bar)[d, n] * powers[d, n, l]
    K = einsum('dn,dnl->dl', self.C * B_bar, powers)
    return K.real  # Take real part
```

See Nexus/nexus/components/ssm/s4.py for the full implementation.
- `hippo_initializer(N)`: Constructs the HiPPO-LegS matrix
  ```python
  P = sqrt(1 + 2 * arange(N))
  A = P.unsqueeze(1) * P.unsqueeze(0)
  A = tril(A) - diag(arange(N))  # tril keeps 2n+1 on the diagonal; subtracting n leaves n+1
  return -A
  ```
- `dplr_decomposition(A)`: Extracts diagonal and low-rank components
  - Uses eigendecomposition for numerical stability
  - Returns complex-valued Lambda, P, Q
- `S4Kernel`: Manages DPLR parameters and kernel generation
  - Stores Lambda, P, B, C in complex form (real/imag parts)
  - Computes kernel via Vandermonde in frequency domain
- `S4Layer`: Wraps the kernel with normalization and activation
  - Supports bidirectional processing (optional)
  - GELU activation and dropout
- `S4Block`: Full block with residual connection and FFN
  - Pre-normalization
  - Residual connections around S4 and FFN
Store complex parameters as separate real/imaginary parts:
```python
self.Lambda_real = Parameter(Lambda.real)
self.Lambda_imag = Parameter(Lambda.imag)
# Reconstruct when needed
Lambda = complex(self.Lambda_real, self.Lambda_imag)
```
PyTorch's complex autograd is slow; storing the real and imaginary parts separately is faster.
Store log(dt) instead of dt to ensure positivity:
```python
self.log_dt = Parameter(log(dt_init))
dt = exp(self.log_dt)  # Always positive
```
For many tasks, restricting A to be purely diagonal (no low-rank correction) works well:
- Faster computation (no Woodbury identity)
- Less memory
- See the S4D paper for details
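In the purely diagonal case, the whole kernel collapses to a Vandermonde contraction. A minimal numpy sketch (illustrative eigenvalues in conjugate pairs, not the S4D initialization):

```python
import numpy as np

# Diagonal-SSM kernel via a Vandermonde contraction: no Woodbury correction needed.
N, L = 4, 8
lam = np.array([-0.5 + 3.0j, -0.5 - 3.0j, -1.0 + 1.0j, -1.0 - 1.0j])  # conjugate pairs
dt = 0.05
B = np.ones(N)
C = np.ones(N)

A_bar = np.exp(dt * lam)                        # ZOH on a diagonal matrix: elementwise
B_bar = (A_bar - 1.0) / lam * B                 # (A_bar - I) Lambda^{-1} B, elementwise
vand = A_bar[:, None] ** np.arange(L)[None, :]  # Vandermonde matrix A_bar^i, shape (N, L)
K = ((C * B_bar)[:, None] * vand).sum(axis=0).real  # K[i] = Re(C A_bar^i B_bar)
```

Pairing eigenvalues as complex conjugates keeps the kernel real, which is why the final `.real` is safe.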
Cache discretized parameters during inference:
```python
@torch.no_grad()
def cache_discretization(self):
    dt = exp(self.log_dt)
    self.A_bar_cached = exp(dt * self.Lambda)
    self.B_bar_cached = dt * self.B
```
Pre-compute kernels for common sequence lengths (64, 128, 256, etc.):
```python
kernel_cache = {
    64: self.compute_kernel(64),
    128: self.compute_kernel(128),
    # ...
}
```
S4 achieves state-of-the-art results on the Long Range Arena (LRA), a benchmark for long-range dependencies:
| Task | Sequence Length | S4 Accuracy | Transformer | LSTM |
|---|---|---|---|---|
| ListOps | 2K | 59.6% | 36.4% | 18.0% |
| Text | 4K | 86.8% | 64.3% | 62.9% |
| Retrieval | 4K | 90.9% | 57.5% | 50.0% |
| Image | 1K | 88.7% | 42.4% | 60.2% |
| Pathfinder | 1K | 86.1% | 71.4% | 62.2% |
| Path-X | 16K | 88.0% | fail | fail |
At the time of publication, S4 was the only model to succeed on Path-X (16K sequence length).
S4 matches or exceeds Transformer performance on speech:
- SC09 (Speech Commands): 96.5% accuracy (vs 96.2% Transformer)
- Lower latency in streaming scenarios
- Better handling of variable-length audio
Training time (relative to Transformer):
- Sequence length 1K: 0.8x (slightly faster)
- Sequence length 4K: 0.4x (2.5x faster)
- Sequence length 16K: 0.15x (6.7x faster)
Memory usage (relative to Transformer):
- 1K: 0.6x
- 4K: 0.15x
- 16K: 0.04x (25x less memory!)
S4 scales well to large models:
- S4-GPT (125M params): Competitive with GPT-2 on WikiText-103
- S4-BERT (110M params): 95% of BERT-base performance on GLUE
- Linear complexity enables training on longer contexts
Problem: Using the Euler method instead of ZOH:
```python
# BAD: Euler discretization
A_bar = I + dt * A
B_bar = dt * B
```
Solution: Use proper ZOH:
```python
# GOOD: Zero-order hold
A_bar = exp(dt * A)
B_bar = (A_bar - I) @ inv(A) @ B
```
ZOH is crucial for numerical stability with large dt.
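A scalar example makes the stability difference concrete. For A = -1 and a large step dt = 3, Euler produces |Ā| > 1, so the recurrence blows up, while ZOH always stays in (0, 1) for A < 0:

```python
import numpy as np

# Euler vs. ZOH discretization of a stable scalar system dx/dt = A x, A = -1.
A, dt = -1.0, 3.0
A_bar_euler = 1.0 + dt * A     # = -2.0: |A_bar| > 1, state doubles in magnitude each step
A_bar_zoh = np.exp(dt * A)     # ~ 0.0498: decays, matching the continuous system
```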
Problem: Trying to work with real-valued A only.
Solution: HiPPO eigenvalues are complex. Must use complex arithmetic or carefully manage real/imaginary parts.
Problem: Standard convolution is not causal (sees the future):
```python
# BAD: Non-causal
y = F.conv1d(u, kernel)
```
Solution: Pad and truncate for causality:
```python
# GOOD: Causal convolution
u_padded = F.pad(u, (L - 1, 0))
y = F.conv1d(u_padded, kernel)[:, :, :L]
```
Problem: Omitting the D parameter.
Solution: D is crucial for modeling the current input:
```python
y = SSM(u) + D * u  # Always include skip connection
```
Problem: Setting N too small (e.g., N=8).
Solution: Use N >= 64 for most tasks. HiPPO needs sufficient dimensions to capture history.
Problem: Random initialization of dt.
Solution: Initialize dt to a reasonable range:
```python
log_dt = rand(d_model) * (log(0.1) - log(0.001)) + log(0.001)
# dt in range [0.001, 0.1]
```
Problem: Kernel computation overflows for very long L.
Solution:
- Clip dt to prevent exp(dt * A) overflow
- Use mixed precision (fp32 for the SSM, fp16 for other layers)
- Consider chunking very long sequences
```python
def init_s4_layer(layer):
    # 1. Initialize B, C with proper scaling
    nn.init.xavier_uniform_(layer.B_real)
    nn.init.xavier_uniform_(layer.C_real)
    # 2. Initialize dt in log-space
    dt_min, dt_max = 0.001, 0.1
    log_dt = rand(d_model) * (log(dt_max) - log(dt_min)) + log(dt_min)
    layer.log_dt.data = log_dt
    # 3. Initialize D close to zero
    nn.init.zeros_(layer.D)
    # 4. Lambda, P, Q from HiPPO (done in constructor)
```
- S4 (ICLR 2022)
  - Gu, Goel, Ré. "Efficiently Modeling Long Sequences with Structured State Spaces"
  - https://arxiv.org/abs/2111.00396
  - Introduces S4, HiPPO, DPLR
- HiPPO (NeurIPS 2020)
  - Gu, Dao, Ermon, Rudra, Ré. "HiPPO: Recurrent Memory with Optimal Polynomial Projections"
  - https://arxiv.org/abs/2008.07669
  - Mathematical foundation for state initialization
- LSSL (NeurIPS 2021)
  - Gu, Johnson, Goel, Saab, Dao, Rudra, Ré. "Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers"
  - https://arxiv.org/abs/2110.13985
  - Precursor to S4
- S4D (NeurIPS 2022)
  - Gu, Gupta, Goel, Ré. "On the Parameterization and Initialization of Diagonal State Space Models"
  - https://arxiv.org/abs/2206.11893
  - Diagonal simplification
- DSS (NeurIPS 2022)
  - Gupta, Gu, Berant. "Diagonal State Spaces are as Effective as Structured State Spaces"
  - https://arxiv.org/abs/2203.14343
  - Further analysis of the diagonal restriction
- S4ND (NeurIPS 2022)
  - Nguyen, Goel, Rudra, Ré. "S4ND: Modeling Images and Videos as Multidimensional Signals Using State Spaces"
  - https://arxiv.org/abs/2210.06583
  - Extension to images/video
- The Annotated S4
  - Sasha Rush & Sidd Karamcheti
  - https://srush.github.io/annotated-s4/
  - Line-by-line implementation guide
- S4 Explained
  - Hugging Face Blog
  - https://huggingface.co/blog/lbourdois/get-on-the-ssm-train
  - Intuitive explanation
When implementing S4 from scratch:
- HiPPO initialization for A matrix
- DPLR decomposition (Lambda, P, Q)
- Complex number handling (real/imag split)
- Proper ZOH discretization (not Euler)
- Kernel computation via Vandermonde
- FFT-based convolution for training
- Recurrent mode for inference
- Skip connection D
- Log-space dt parameterization
- Causal masking in convolution
- Numerical stability (clipping, mixed precision)
- Bidirectional support (optional)
- Kernel caching for common lengths
For implementation reference, see Nexus/nexus/components/ssm/s4.py