
S4: Structured State Spaces for Sequence Modeling

Overview & Motivation

S4 (Structured State Space sequence model) revolutionized sequence modeling by showing that continuous-time state space models, when properly initialized and parameterized, can effectively model long-range dependencies with linear complexity. Before S4, RNNs struggled with long sequences due to vanishing gradients, while Transformers required quadratic complexity.

Why S4 vs Transformers?

Aspect                  Transformers                    S4
Training complexity     O(n² d)                         O(n log n · d)
Inference complexity    O(n² d)                         O(n d), or O(d) per step
Long-range modeling     Requires positional encodings   Built in through HiPPO
Memory usage            High (attention matrix)         Low (state vector)
Parallelization         Excellent                       Good

S4 bridges the gap between RNNs (efficient inference) and Transformers (parallelizable training).

Theoretical Background

State Space Fundamentals

State space models describe dynamical systems through differential equations:

Continuous time:
  dx/dt = Ax(t) + Bu(t)
  y(t) = Cx(t) + Du(t)

Discrete time (after discretization with step Δ):
  x_k = Ā x_{k-1} + B̄ u_k
  y_k = C x_k + D u_k

where:

  • x(t) ∈ ℝᴺ is the hidden state (N-dimensional)
  • u(t) ∈ ℝ is the input (scalar or vector)
  • y(t) ∈ ℝ is the output
  • A ∈ ℝᴺˣᴺ is the state transition matrix
  • B ∈ ℝᴺ is the input matrix
  • C ∈ ℝᴺ is the output matrix
  • D ∈ ℝ is the feedthrough (skip connection)

HiPPO Theory

The key innovation of S4 is using HiPPO (High-order Polynomial Projection Operators) to initialize the state matrix A. HiPPO matrices are derived from the optimal solution to online function approximation:

Given a continuous signal f(t), compress its history into N coefficients that minimize reconstruction error. The HiPPO-LegS (Legendre, scaled) measure gives:

A_{nk} = -(2n+1)^{1/2} (2k+1)^{1/2}  if n > k
A_{nn} = -(n+1)                       if n = k
A_{nk} = 0                             if n < k

This initialization enables S4 to remember information over very long sequences (>10k tokens).
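As a sanity check on the piecewise formula, the matrix can be built directly in a few lines (a minimal NumPy sketch; `hippo_legs` is an illustrative name, not part of any S4 codebase):

```python
import numpy as np

def hippo_legs(N):
    # Build the HiPPO-LegS matrix entry-by-entry from the piecewise formula
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = -(n + 1)
            # n < k: entry stays 0 (strict upper triangle)
    return A

A = hippo_legs(4)
assert np.allclose(A, np.tril(A))                 # lower triangular
assert (np.diag(A) == -(np.arange(4) + 1)).all()  # diagonal is -(n+1)
```

A vectorized construction of the same matrix appears in the Code Walkthrough below.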

Mathematical Formulation

1. Continuous-Time System

The continuous-time SSM is:

ẋ(t) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)

2. Discretization (Zero-Order Hold)

To make the system trainable with discrete sequences, we discretize using zero-order hold (ZOH):

x_k = exp(ΔA) x_{k-1} + (∫₀^Δ exp(sA) ds) B u_k

This simplifies to:

Ā = exp(ΔA)
B̄ = (Ā - I) A⁻¹ B

x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k + D u_k
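To see why these are the right formulas, here is a scalar sanity check (NumPy, illustrative values only): holding u constant over one step — the ZOH assumption — the discrete update reproduces the exact solution of the ODE, which a finely-stepped Euler integration approaches.

```python
import numpy as np

# Scalar SSM dx/dt = a*x + b*u, with u held constant over the step (ZOH assumption)
a, b, dt, u, x0 = -0.5, 1.0, 0.1, 2.0, 1.0

# ZOH formulas above, specialized to scalars
A_bar = np.exp(dt * a)
B_bar = (A_bar - 1.0) / a * b
x_zoh = A_bar * x0 + B_bar * u

# Reference: integrate the ODE over [0, dt] with many tiny Euler steps
x, n = x0, 100_000
h = dt / n
for _ in range(n):
    x = x + h * (a * x + b * u)

assert abs(x - x_zoh) < 1e-4  # ZOH matches the true continuous-time solution
```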

3. Convolution Representation

The discrete SSM defines a convolution kernel:

K = {CB̄, CĀB̄, CĀ²B̄, ..., CĀᴸ⁻¹B̄}

The output is then:

y_k = (K * u)_k = ∑_{i=0}^{k} K_i u_{k-i}

This allows parallel computation via FFT in O(L log L) time.
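The equivalence between the recurrence and the convolution can be verified directly. A minimal NumPy sketch with a diagonal Ā (S4 proper uses the DPLR form, but the identity is the same):

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16
# Random stable diagonal SSM (entries inside the unit interval)
A_bar = rng.uniform(0.1, 0.9, N)
B_bar = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

# Recurrent mode: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k
x = np.zeros(N)
y_rec = np.zeros(L)
for k in range(L):
    x = A_bar * x + B_bar * u[k]
    y_rec[k] = C @ x

# Convolution mode: K_i = C A_bar^i B_bar, y = causal convolution K * u
K = np.array([C @ (A_bar ** i * B_bar) for i in range(L)])
y_conv = np.convolve(u, K)[:L]

assert np.allclose(y_rec, y_conv)  # both modes give identical outputs
```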

4. DPLR Decomposition

Computing the kernel naively is O(N²L). S4 uses the DPLR (Diagonal Plus Low-Rank) structure:

A = Λ + PQ*

where:

  • Λ is a complex diagonal matrix (N×N)
  • P, Q are low-rank factors (N×r, typically r=1)

This decomposition enables efficient kernel computation via the Woodbury identity and FFT.
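Concretely, the quantity S4 needs is the resolvent (zI − A)⁻¹, and with A = Λ + PQ* the Woodbury identity reduces it to elementwise diagonal inverses plus a rank-r correction. A small numerical check (NumPy, r = 1, illustrative values):

```python
import numpy as np

rng = np.random.default_rng(1)
N, r = 6, 1
Lam = rng.standard_normal(N) + 1j * rng.standard_normal(N)  # diagonal part
P = rng.standard_normal((N, r)).astype(complex)
Q = rng.standard_normal((N, r)).astype(complex)
A = np.diag(Lam) + P @ Q.conj().T                           # DPLR matrix

z = 2.0 + 0.5j
D_inv = 1.0 / (z - Lam)                                     # O(N): elementwise inverse

# Woodbury: (z - A)^-1 = D^-1 + D^-1 P (I - Q* D^-1 P)^-1 Q* D^-1,
# with D = diag(z - Lam); only an r x r system is ever solved
DP = D_inv[:, None] * P                                     # D^-1 P, shape (N, r)
core = np.linalg.inv(np.eye(r) - Q.conj().T @ DP)           # (r, r)
resolvent = np.diag(D_inv) + DP @ core @ (Q.conj().T * D_inv[None, :])

assert np.allclose(resolvent, np.linalg.inv(z * np.eye(N) - A))
```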

High-Level Intuition

Think of S4 as a learned filter bank that processes sequences:

  1. State as Memory: The N-dimensional state x compresses the sequence history
  2. HiPPO as Initialization: Tells the model how to compress effectively (like Legendre polynomials)
  3. Convolution Mode (Training): Process entire sequence in parallel via FFT
  4. Recurrence Mode (Inference): Update state one step at a time with O(1) complexity

The magic is that the same model can:

  • Train fast (parallel FFT convolution)
  • Infer fast (sequential O(1) updates)
  • Remember long-range dependencies (HiPPO initialization)

Implementation Details

Architecture Components

import torch
from torch import nn

class S4Block(nn.Module):
    def __init__(self, d_model, d_state=64):
        super().__init__()

        # 1. Initialize HiPPO matrix A
        A = hippo_initializer(d_state)  # (N, N)

        # 2. Decompose into DPLR form
        self.Lambda, self.P, self.Q = dplr_decomposition(A)

        # 3. Learnable input/output projections
        self.B = nn.Parameter(torch.randn(d_model, d_state))
        self.C = nn.Parameter(torch.randn(d_model, d_state))

        # 4. Discretization step (learnable, stored in log-space)
        self.log_dt = nn.Parameter(torch.randn(d_model))

        # 5. Skip connection
        self.D = nn.Parameter(torch.randn(d_model))

Training Mode: Convolution

def forward_conv(self, u):
    # u: (batch, L, d_model)
    L = u.shape[1]

    # 1. Generate SSM convolution kernel
    K = self.compute_kernel(L)  # (d_model, L)

    # 2. FFT convolution, zero-padded to length 2L so the FFT's circular
    #    convolution matches a causal linear convolution
    u_fft = rfft(u, n=2 * L, dim=1)                  # O(L log L)
    K_fft = rfft(K.transpose(0, 1), n=2 * L, dim=0)
    y = irfft(u_fft * K_fft, n=2 * L, dim=1)[:, :L]  # truncate to causal part

    # 3. Add skip connection
    y = y + self.D * u

    return y

Inference Mode: Recurrence

def forward_recurrent(self, u, state):
    # u: (batch, d_model) - single timestep
    # state: (batch, d_model, d_state), complex-valued

    # 1. Discretize with ZOH (closed form for diagonal Lambda)
    dt = exp(self.log_dt)
    A_bar = exp(dt.unsqueeze(-1) * self.Lambda)  # (d_model, N), complex diagonal
    B_bar = (A_bar - 1) / self.Lambda * self.B   # ZOH, not the Euler dt * B

    # 2. State update (O(N) per step)
    state = A_bar * state + B_bar * u.unsqueeze(-1)

    # 3. Output (real part of the complex readout)
    y = (self.C * state).sum(dim=-1).real + self.D * u

    return y, state

Efficient Kernel Computation

The kernel K[i] = C Āⁱ B̄ is computed via DPLR structure:

def compute_kernel(self, L):
    dt = exp(self.log_dt)

    # Discretize with ZOH (consistent with the recurrent mode)
    A_bar = exp(dt.unsqueeze(-1) * self.Lambda)  # (d_model, N)
    B_bar = (A_bar - 1) / self.Lambda * self.B

    # Vandermonde: powers of A_bar
    # K[i] = C * A_bar^i * B_bar
    powers = A_bar.unsqueeze(-1) ** arange(L)    # (d_model, N, L)

    # K[l] = sum_n (C * B_bar)[n] * A_bar[n]^l
    K = einsum('dn,dnl->dl', self.C * B_bar, powers)

    return K.real  # take the real part

Code Walkthrough

See Nexus/nexus/components/ssm/s4.py for full implementation.

Key Functions

  1. hippo_initializer(N): Constructs HiPPO-LegS matrix

    P = sqrt(1 + 2 * arange(N))
    A = P.unsqueeze(1) * P.unsqueeze(0)
    A = tril(A) - diag(arange(N))  # diagonal becomes (2n+1) - n = n+1
    return -A
  2. dplr_decomposition(A): Extracts diagonal and low-rank components

    • Uses eigendecomposition for numerical stability
    • Returns complex-valued Lambda, P, Q
  3. S4Kernel: Manages DPLR parameters and kernel generation

    • Stores Lambda, P, B, C in complex form (real/imag parts)
    • Computes kernel via Vandermonde in frequency domain
  4. S4Layer: Wraps kernel with normalization and activation

    • Supports bidirectional processing (optional)
    • GELU activation and dropout
  5. S4Block: Full block with residual connection and FFN

    • Pre-normalization
    • Residual connections around S4 and FFN

Optimization Tricks

1. Complex Number Handling

Store complex parameters as separate real/imaginary parts:

self.Lambda_real = nn.Parameter(Lambda.real)
self.Lambda_imag = nn.Parameter(Lambda.imag)

# Reconstruct when needed
Lambda = torch.complex(self.Lambda_real, self.Lambda_imag)

PyTorch's complex autograd is slow; this is faster.

2. Log-Space Discretization

Store log(dt) instead of dt to ensure positivity:

self.log_dt = Parameter(log(dt_init))
dt = exp(self.log_dt)  # Always positive

3. Diagonal A Simplification (S4D)

For many tasks, restricting A to be purely diagonal (no low-rank correction) works well:

  • Faster computation (no Woodbury identity)
  • Less memory
  • See S4D paper for details

4. Caching for Inference

Cache discretized parameters during inference:

@torch.no_grad()
def cache_discretization(self):
    dt = exp(self.log_dt)
    self.A_bar_cached = exp(dt.unsqueeze(-1) * self.Lambda)
    self.B_bar_cached = (self.A_bar_cached - 1) / self.Lambda * self.B  # ZOH

5. Sequence Length Binning

Pre-compute kernels for common sequence lengths (64, 128, 256, etc.):

kernel_cache = {
    64: self.compute_kernel(64),
    128: self.compute_kernel(128),
    # ...
}

Experiments & Results

Long Range Arena (LRA) Benchmark

S4 achieves state-of-the-art on LRA, a benchmark for long-range dependencies:

Task         Sequence Length   S4 Accuracy   Transformer   LSTM
ListOps      2K                59.6%         36.4%         18.0%
Text         4K                86.8%         64.3%         62.9%
Retrieval    4K                90.9%         57.5%         50.0%
Image        1K                88.7%         42.4%         60.2%
Pathfinder   1K                86.1%         71.4%         62.2%
Path-X       16K               88.0%         fail          fail

At publication, S4 was the only model to achieve better-than-chance accuracy on Path-X (16K length).

Speech Recognition

S4 matches or exceeds Transformer performance on speech:

  • SC09 (Speech Commands): 96.5% accuracy (vs 96.2% Transformer)
  • Lower latency in streaming scenarios
  • Better handling of variable-length audio

Computational Efficiency

Training time (relative to Transformer):

  • Sequence length 1K: 0.8x (slightly faster)
  • Sequence length 4K: 0.4x (2.5x faster)
  • Sequence length 16K: 0.15x (6.7x faster)

Memory usage (relative to Transformer):

  • 1K: 0.6x
  • 4K: 0.15x
  • 16K: 0.04x (25x less memory!)

Scaling to Large Models

S4 scales well to large models:

  • S4-GPT (125M params): Competitive with GPT-2 on WikiText-103
  • S4-BERT (110M params): 95% of BERT-base performance on GLUE
  • Linear complexity enables training on longer contexts

Common Pitfalls

1. Wrong Discretization Method

Problem: Using Euler method instead of ZOH:

# BAD: Euler discretization
A_bar = I + dt * A
B_bar = dt * B

Solution: Use proper ZOH:

# GOOD: Zero-order hold
A_bar = exp(dt * A)
B_bar = (A_bar - I) @ inv(A) @ B

ZOH is crucial for numerical stability with large dt.
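A scalar illustration of the instability (illustrative values): for a stable continuous-time system the ZOH factor exp(dt·a) is always contractive, while the Euler factor 1 + dt·a leaves the unit circle once dt grows.

```python
import numpy as np

a = -2.0  # stable continuous-time pole (Re(a) < 0)

for dt in (0.01, 0.5, 1.5):
    zoh = np.exp(dt * a)    # always in (0, 1): the state decays, as it should
    euler = 1.0 + dt * a    # crosses 0 at dt = 0.5 and -1 at dt = 1
    print(f"dt={dt}: ZOH factor {zoh:.4f}, Euler factor {euler:.4f}")

# At dt = 1.5 the Euler-discretized system is unstable (|factor| > 1); ZOH is not
assert abs(np.exp(1.5 * a)) < 1.0
assert abs(1.0 + 1.5 * a) > 1.0
```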

2. Not Using Complex Numbers

Problem: Trying to work with real-valued A only.

Solution: HiPPO eigenvalues are complex. Must use complex arithmetic or carefully manage real/imaginary parts.

3. Incorrect Causal Masking in Convolution

Problem: Standard convolution is not causal (sees future):

# BAD: Non-causal
y = F.conv1d(u, kernel)

Solution: Pad and truncate for causality:

# GOOD: Causal convolution (note conv1d cross-correlates, so flip the kernel)
u_padded = F.pad(u, (L - 1, 0))
y = F.conv1d(u_padded, kernel.flip(-1))[:, :, :L]
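The same causality requirement applies in the FFT path used for training: zero-padding to 2L turns the FFT's circular convolution into the causal linear one. A NumPy check:

```python
import numpy as np

rng = np.random.default_rng(2)
L = 8
u = rng.standard_normal(L)
K = rng.standard_normal(L)

# Direct causal convolution: y_k = sum_{i<=k} K_i u_{k-i}
y_direct = np.convolve(u, K)[:L]

# FFT convolution: zero-pad to 2L so the circular wrap-around falls
# outside the first L outputs, then truncate
y_fft = np.fft.irfft(np.fft.rfft(u, n=2 * L) * np.fft.rfft(K, n=2 * L), n=2 * L)[:L]

assert np.allclose(y_direct, y_fft)
```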

4. Forgetting the Skip Connection D

Problem: Omitting the D parameter.

Solution: D is crucial for modeling the current input:

y = SSM(u) + D * u  # Always include skip connection

5. Using Too Small State Dimension

Problem: Setting N too small (e.g., N=8).

Solution: Use N >= 64 for most tasks. HiPPO needs sufficient dimensions to capture history.

6. Not Initializing log_dt Properly

Problem: Random initialization of dt.

Solution: Initialize dt to reasonable range:

log_dt = rand(d_model) * (log(0.1) - log(0.001)) + log(0.001)
# dt in range [0.001, 0.1]
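This recipe samples dt log-uniformly: uniform in log-space, then exponentiated, so every decade between dt_min and dt_max is equally represented. A quick NumPy check:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, dt_min, dt_max = 1000, 0.001, 0.1

# Uniform in log-space, then exponentiate -> log-uniform dt
log_dt = rng.random(d_model) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)
dt = np.exp(log_dt)

assert dt.min() >= dt_min - 1e-12 and dt.max() < dt_max
# Roughly half the samples land below the geometric mean sqrt(dt_min * dt_max)
assert 0.3 < (dt < np.sqrt(dt_min * dt_max)).mean() < 0.7
```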

7. Numerical Instability with Large Sequences

Problem: Kernel computation overflows for very long L.

Solution:

  • Clip dt to prevent exp(dt * A) overflow
  • Use mixed precision (fp32 for SSM, fp16 for other layers)
  • Consider chunking very long sequences

Initialization Best Practices

def init_s4_layer(layer):
    # 1. Initialize B, C with proper scaling
    nn.init.xavier_uniform_(layer.B_real)
    nn.init.xavier_uniform_(layer.C_real)

    # 2. Initialize dt in log-space
    dt_min, dt_max = 0.001, 0.1
    log_dt = rand(d_model) * (log(dt_max) - log(dt_min)) + log(dt_min)
    layer.log_dt.data = log_dt

    # 3. Initialize D close to zero
    nn.init.zeros_(layer.D)

    # 4. Lambda, P, Q from HiPPO (done in constructor)

References

Primary Papers

  1. S4 (ICLR 2022)

    • Gu, Goel, Ré. "Efficiently Modeling Long Sequences with Structured State Spaces"
    • https://arxiv.org/abs/2111.00396
    • The main S4 paper
  2. HiPPO (NeurIPS 2020)

    • Gu, Dao, Ermon, Rudra, Ré. "HiPPO: Recurrent Memory with Optimal Polynomial Projections"
    • https://arxiv.org/abs/2008.07669
    • Mathematical foundation for state initialization
  3. LSSL (NeurIPS 2021)

    • Gu, Johnson, Goel, Saab, Dao, Rudra, Ré. "Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers"
    • https://arxiv.org/abs/2110.13985
    • Precursor to S4

Related Work

  1. S4D (NeurIPS 2022)

    • Gu, Gupta, Goel, Ré. "On the Parameterization and Initialization of Diagonal State Space Models"
    • https://arxiv.org/abs/2206.11893
    • Diagonal-only simplification of S4
  2. DSS (ICML 2022)

    • Gupta, Gu, Berant. "Diagonal State Spaces are as Effective as Structured State Spaces"
    • https://arxiv.org/abs/2203.14343
    • Further analysis of diagonal restriction
  3. S4ND (NeurIPS 2022)

    • Nguyen, Goel, Rudra, Ré. "S4ND: Modeling Images and Videos as Multidimensional Signals Using State Spaces"
    • https://arxiv.org/abs/2210.06583
    • Extension to images/video

Tutorials and Blogs

  1. The Annotated S4 (Rush & Karamcheti)

    • https://srush.github.io/annotated-s4/
  2. S4 Explained

Implementation Checklist

When implementing S4 from scratch:

  • HiPPO initialization for A matrix
  • DPLR decomposition (Lambda, P, Q)
  • Complex number handling (real/imag split)
  • Proper ZOH discretization (not Euler)
  • Kernel computation via Vandermonde
  • FFT-based convolution for training
  • Recurrent mode for inference
  • Skip connection D
  • Log-space dt parameterization
  • Causal masking in convolution
  • Numerical stability (clipping, mixed precision)
  • Bidirectional support (optional)
  • Kernel caching for common lengths

For implementation reference, see Nexus/nexus/components/ssm/s4.py