
Commit aec8af5

Added S4 and removed S6
1 parent d4d1827 commit aec8af5

File tree

3 files changed: +309 -298 lines changed

models/S4.py

Lines changed: 306 additions & 0 deletions
@@ -0,0 +1,306 @@
"""
Implementation of the S4 model taken from https://github.com/state-spaces/s4
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed:
                X = rearrange(X, "b ... d -> b d ...")
            # Sampling from self.binomial here is incredibly slow because of
            # CPU -> GPU copying, so the mask is drawn with torch.rand instead.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
            # Rescale the kept activations so the expected value is unchanged
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, "b d ... -> b ... d")
            return X
        return X
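

# A minimal usage sketch for DropoutNd (illustrative; not part of the original
# file). With tie=True the mask has shape (batch, dim, 1, ...) and is broadcast
# over the length dimensions, so each channel is kept or dropped wholesale.
def _example_dropout_nd():
    drop = DropoutNd(p=0.1, tie=True, transposed=True)
    drop.train()  # the mask is only applied in training mode
    x = torch.randn(4, 16, 128)  # (batch, dim, length)
    return drop(x)  # same shape as x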


class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Generate dt: one step size per channel, log-uniform in [dt_min, dt_max]
        H = d_model
        log_dt = torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(
            dt_min
        )

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        log_A_real = torch.log(0.5 * torch.ones(H, N // 2))
        A_imag = math.pi * repeat(torch.arange(N // 2), "n -> h n", h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (H, L) convolution kernel
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt)  # (H)
        C = torch.view_as_complex(self.C)  # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag  # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device)  # (H N L)
        C = C * (torch.exp(dtA) - 1.0) / A
        K = 2 * torch.einsum("hn, hnl -> hl", C, torch.exp(K)).real

        return K
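
    # In closed form, the kernel computed above is
    #     K[h, l] = 2 * Re( sum_n C'[h, n] * exp(dt[h] * A[h, n] * l) ),
    # with C'[h, n] = C[h, n] * (exp(dt[h] * A[h, n]) - 1) / A[h, n],
    # which matches the zero-order-hold discretization of the diagonal SSM.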

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)
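

# A hedged sketch (not part of the original file) of how the `_optim` hints set
# by S4DKernel.register are typically consumed: parameters carrying an `_optim`
# attribute get their own optimizer param group, so the stored lr /
# weight_decay override the optimizer defaults.
def _example_setup_optimizer(model, lr=1e-3, weight_decay=0.01):
    hooked = [p for p in model.parameters() if hasattr(p, "_optim")]
    plain = [p for p in model.parameters() if not hasattr(p, "_optim")]
    groups = [{"params": plain}] + [{"params": [p], **p._optim} for p in hooked]
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)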


class S4D(nn.Module):
    def __init__(
        self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args
    ):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise
        self.activation = nn.GELU()
        # dropout_fn = nn.Dropout2d  # NOTE: bugged in PyTorch 1.11
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),  # splits the doubled channel dim back to h
        )

    def forward(self, u, **kwargs):  # absorbs return_output and transformer src mask
        """Input shape (B, L, H); output shape (B, H, L), or (B, L, H) if
        transposed=False."""
        u = u.transpose(-1, -2)  # (B L H) -> (B H L)
        L = u.size(-1)

        # Compute SSM Kernel
        k = self.kernel(L=L)  # (H L)

        # Convolution: pad to length 2L so the circular FFT convolution
        # equals the causal linear convolution
        k_f = torch.fft.rfft(k, n=2 * L)  # (H L)
        u_f = torch.fft.rfft(u, n=2 * L)  # (B H L)
        y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]  # (B H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y
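

# A minimal sanity check (illustrative; not part of the original file) that the
# 2L zero padding used in S4D.forward turns circular FFT convolution into the
# causal linear convolution  y[i] = sum_{j <= i} k[j] * u[i - j]:
def _example_fft_conv_check(L=8):
    k, u = torch.randn(L), torch.randn(L)
    y_fft = torch.fft.irfft(
        torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(k, n=2 * L), n=2 * L
    )[:L]
    y_direct = torch.stack(
        [torch.dot(k[: i + 1], u[: i + 1].flip(0)) for i in range(L)]
    )
    assert torch.allclose(y_fft, y_direct, atol=1e-5)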


class S4Block(nn.Module):
    """
    A single S4 block that applies:
    1. S4D module
    2. Residual connection
    3. (Optionally) a Linear -> GLU stage with its own residual
    4. Layer normalization
    5. Dropout

    Args:
        model_dim (int): Dimensionality of the model (d_model).
        dropout_rate (float): Probability of an element to be zeroed in Dropout.
        use_glu (bool): Whether to apply a Linear -> GLU stage after the residual.
    """

    def __init__(
        self, model_dim: int, dropout_rate: float = 0.1, use_glu: bool = False
    ):
        super().__init__()
        self.s4 = S4D(d_model=model_dim)
        self.norm = nn.LayerNorm(model_dim)
        self.drop = nn.Dropout(p=dropout_rate)

        self.use_glu = use_glu
        if self.use_glu:
            # The linear expands from model_dim to 2*model_dim
            # so that GLU can split it into two halves of model_dim each
            self.post_linear = nn.Linear(model_dim, 2 * model_dim)
        else:
            self.post_linear = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the S4Block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, model_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape (batch_size, seq_len, model_dim).
        """

        # S4 module (returns (batch_size, model_dim, seq_len))
        y = self.s4(x)
        y = y.transpose(1, 2)  # back to (batch_size, seq_len, model_dim)
        y = y + x

        # Optional: Linear -> GLU
        if self.use_glu:
            # shape: (batch_size, seq_len, 2 * model_dim)
            y_glu = self.post_linear(y)
            # shape: (batch_size, seq_len, model_dim)
            y_glu = F.glu(y_glu, dim=-1)
            y = y + y_glu

        # Layer Normalization
        y = self.norm(y)

        # Dropout
        y = self.drop(y)

        return y
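

# Usage sketch for S4Block (illustrative; not part of the original file):
def _example_s4_block():
    block = S4Block(model_dim=64, dropout_rate=0.1, use_glu=True)
    x = torch.randn(2, 100, 64)  # (batch_size, seq_len, model_dim)
    return block(x)  # shape preserved: (2, 100, 64)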


class StackedS4(nn.Module):
    """
    A stack of multiple S4Blocks, preceded by an embedding layer
    and followed by a linear projection.

    Args:
        num_blocks (int): Number of S4Blocks to stack.
        model_dim (int): Dimensionality of embeddings and S4 blocks.
        data_dim (int): Size of the vocabulary (if input is token IDs).
        label_dim (int): Output dimensionality (e.g., number of classes).
        dropout_rate (float): Dropout probability for each S4Block.
        use_glu (bool): If True, each block will include a Linear->GLU stage
            that preserves model_dim.
        second_embedding (bool): If True, the model expects two input token IDs
            per position and uses two separate embeddings, each of dimension
            model_dim // 2, concatenated to model_dim.
    """

    def __init__(
        self,
        num_blocks: int,
        model_dim: int,
        data_dim: int,
        label_dim: int,
        dropout_rate: float = 0.1,
        use_glu: bool = False,
        second_embedding: bool = False,
    ):
        super().__init__()

        self.second_embedding = second_embedding
        embedding_dim = model_dim // 2 if second_embedding else model_dim
        self.embedding = nn.Embedding(data_dim, embedding_dim)
        if second_embedding:
            self.embedding2 = nn.Embedding(data_dim, embedding_dim)

        # Create multiple S4Blocks
        self.blocks = nn.ModuleList(
            [
                S4Block(model_dim=model_dim, dropout_rate=dropout_rate, use_glu=use_glu)
                for _ in range(num_blocks)
            ]
        )

        # The final linear projection remains (model_dim -> label_dim)
        self.linear = nn.Linear(model_dim, label_dim)

    def mask_grads(self):
        """
        No-op; included for interface consistency with other models.
        """
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of StackedS4.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len)
                containing integer token IDs, or (batch_size, seq_len, 2)
                when second_embedding is True.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, label_dim).
                If a single-vector output is desired (e.g. for classification),
                additional pooling or indexing may be required
                before the final linear layer or after its output.
        """
        # Embedding: (batch_size, seq_len, model_dim)
        if not self.second_embedding:
            x = self.embedding(x)
        else:
            x = torch.cat(
                [self.embedding(x[:, :, 0]), self.embedding2(x[:, :, 1])], dim=-1
            )

        # Pass through each S4Block
        for block in self.blocks:
            x = block(x)

        # Final projection: (batch_size, seq_len, label_dim)
        return self.linear(x)
291+
292+
def step(self, x: torch.Tensor) -> torch.Tensor:
293+
# Embedding for the current step
294+
if self.second_embedding:
295+
x = torch.cat(
296+
[self.embedding(x[:, 0].long()), self.embedding2(x[:, 1].long())],
297+
dim=-1,
298+
)
299+
else:
300+
x = self.embedding(x)
301+
302+
for block in self.blocks:
303+
x = block.step(x)
304+
305+
# Final projection for the step
306+
return self.linear(x)
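

# Usage sketch for StackedS4 (illustrative; not part of the original file):
def _example_stacked_s4():
    model = StackedS4(num_blocks=4, model_dim=128, data_dim=10, label_dim=2)
    tokens = torch.randint(0, 10, (8, 50))  # (batch_size, seq_len) token IDs
    return model(tokens)  # (8, 50, 2)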
