|
| 1 | +""" |
| 2 | +Implementation of the S4 model taken from https://github.com/state-spaces/s4 |
| 3 | +""" |
| 4 | + |
| 5 | +import math |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +import torch.nn.functional as F |
| 10 | +from einops import rearrange, repeat |
| 11 | + |
| 12 | + |
class DropoutNd(nn.Module):
    """Dropout over (batch, dim, lengths...) tensors, with the option to tie
    the mask across all length dimensions (like Dropout1d/2d/3d)."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), " "but got {}".format(p)
            )
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Kept for parity with the reference implementation; sampling from it
        # every forward pass is very slow (CPU -> GPU copies), so forward()
        # draws the mask with torch.rand instead.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if not self.training:
            # Inference: dropout is the identity.
            return X
        if not self.transposed:
            # (b, ..., d) -> (b, d, ...): bring feature dim next to batch.
            X = X.movedim(-1, 1)
        # Tied masks broadcast over every length dimension.
        keep_shape = X.shape if not self.tie else X.shape[:2] + (1,) * (X.ndim - 2)
        keep = torch.rand(*keep_shape, device=X.device) < 1.0 - self.p
        # Inverted-dropout scaling keeps the expectation unchanged.
        X = X * keep * (1.0 / (1 - self.p))
        if not self.transposed:
            X = X.movedim(1, -1)
        return X
| 42 | + |
| 43 | + |
class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        H = d_model
        # Per-channel step sizes, sampled log-uniformly in [dt_min, dt_max].
        log_dt = math.log(dt_min) + torch.rand(H) * (
            math.log(dt_max) - math.log(dt_min)
        )

        # Complex output projection C, stored as a real (H, N/2, 2) tensor so
        # it can live inside an nn.Parameter.
        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        # S4D-Lin style initialization: Re(A) = -1/2, Im(A) = pi * n.
        log_A_real = torch.log(0.5 * torch.ones(H, N // 2))
        A_imag = math.pi * torch.arange(N // 2).repeat(H, 1)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (..., c, L) where c is number of channels (default 1)
        """
        # Materialize parameters from their stored forms.
        step = torch.exp(self.log_dt)  # (H,)
        proj = torch.view_as_complex(self.C)  # (H, N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag  # (H, N)

        # Discretize, then evaluate the Vandermonde sum over time steps.
        dtA = step.unsqueeze(-1) * A  # (H, N)
        powers = dtA.unsqueeze(-1) * torch.arange(L, device=A.device)  # (H, N, L)
        coeffs = proj * (torch.exp(dtA) - 1.0) / A  # ZOH discretization of C
        return 2 * torch.einsum("hn, hnl -> hl", coeffs, torch.exp(powers)).real

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""
        if lr == 0.0:
            # lr == 0 means frozen: store as a plain buffer with no gradient.
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        # Per-tensor optimizer hints, picked up by the training loop.
        hints = {"weight_decay": 0.0}
        if lr is not None:
            hints["lr"] = lr
        getattr(self, name)._optim = hints
| 94 | + |
| 95 | + |
class S4D(nn.Module):
    """S4D layer: diagonal-SSM FFT convolution followed by pointwise mixing."""

    def __init__(
        self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args
    ):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        # Feed-through ("D") term of the state-space equation.
        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise nonlinearity and (optional) dropout.
        self.activation = nn.GELU()
        # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # Position-wise output transform to mix features.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):  # absorbs return_output and transformer src mask
        """FFT convolution with the SSM kernel plus skip term.

        NOTE(review): despite the original "(B, H, L)" comment, the input is
        consumed as (B, L, H) -- it is transposed to (B, H, L) immediately
        below -- and the output is (B, H, L) when self.transposed, else
        (B, L, H). Callers (e.g. S4Block) rely on exactly this behavior.
        """
        u = u.transpose(-1, -2)
        length = u.size(-1)

        # SSM kernel for this sequence length.
        ssm_kernel = self.kernel(L=length)  # (H, L)

        # Linear (non-circular) convolution via length-2L FFTs.
        kern_f = torch.fft.rfft(ssm_kernel, n=2 * length)  # (H, L)
        sig_f = torch.fft.rfft(u, n=2 * length)  # (B, H, L)
        y = torch.fft.irfft(sig_f * kern_f, n=2 * length)[..., :length]  # (B, H, L)

        # D term of the state space equation - essentially a skip connection.
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y
| 145 | + |
| 146 | + |
class S4Block(nn.Module):
    """
    A single S4 block applying, in order:
      1. S4D module,
      2. Residual connection,
      3. (Optionally) a Linear -> GLU stage with its own residual,
      4. Layer Normalization,
      5. Dropout.

    Args:
        model_dim (int): Dimensionality of the model (d_model).
        dropout_rate (float): Probability of an element to be zeroed in Dropout.
        use_glu (bool): Whether to apply a Linear -> GLU stage after the residual.
    """

    def __init__(
        self, model_dim: int, dropout_rate: float = 0.1, use_glu: bool = False
    ):
        super().__init__()
        self.s4 = S4D(d_model=model_dim)
        self.norm = nn.LayerNorm(model_dim)
        self.drop = nn.Dropout(p=dropout_rate)

        self.use_glu = use_glu
        # The linear expands model_dim -> 2*model_dim so GLU can split it
        # back into two model_dim halves.
        self.post_linear = nn.Linear(model_dim, 2 * model_dim) if use_glu else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the S4Block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, model_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape (batch_size, seq_len, model_dim).
        """
        # S4D consumes (B, L, D) and emits (B, D, L); transpose back before
        # adding the residual.
        out = self.s4(x).transpose(1, 2)  # (batch_size, seq_len, model_dim)
        out = out + x

        # Optional Linear -> GLU stage, added as a second residual:
        # (B, L, 2*D) -> GLU(dim=-1) -> (B, L, D).
        if self.use_glu:
            out = out + F.glu(self.post_linear(out), dim=-1)

        # Layer normalization, then dropout.
        return self.drop(self.norm(out))
| 209 | + |
| 210 | + |
class StackedS4(nn.Module):
    """
    A stack of multiple S4Blocks, preceded by an embedding layer
    and followed by a linear projection.

    Args:
        num_blocks (int): Number of S4Blocks to stack.
        model_dim (int): Dimensionality of embeddings and S4 blocks.
        data_dim (int): Size of the vocabulary (if input is token IDs).
        label_dim (int): Output dimensionality (e.g., number of classes).
        dropout_rate (float): Dropout probability for each S4Block.
        use_glu (bool): If True, each block will include a Linear->GLU stage
            that preserves model_dim.
        second_embedding (bool): If True, the model will expect two input
            token IDs and use two separate embeddings.

    Raises:
        ValueError: If ``second_embedding`` is True and ``model_dim`` is odd,
            since two concatenated ``model_dim // 2`` embeddings could not
            add up to ``model_dim``.
    """

    def __init__(
        self,
        num_blocks: int,
        model_dim: int,
        data_dim: int,
        label_dim: int,
        dropout_rate: float = 0.1,
        use_glu: bool = False,
        second_embedding: bool = False,
    ):
        super().__init__()

        self.second_embedding = second_embedding
        if second_embedding and model_dim % 2 != 0:
            # Two embeddings of model_dim // 2 are concatenated; an odd
            # model_dim would silently yield model_dim - 1 features and break
            # the S4 blocks downstream with a cryptic shape error. Fail fast.
            raise ValueError(
                "model_dim must be even when second_embedding=True, "
                "got {}".format(model_dim)
            )
        embedding_dim = model_dim // 2 if second_embedding else model_dim
        self.embedding = nn.Embedding(data_dim, embedding_dim)
        if second_embedding:
            self.embedding2 = nn.Embedding(data_dim, embedding_dim)

        # Create multiple S4Blocks
        self.blocks = nn.ModuleList(
            [
                S4Block(model_dim=model_dim, dropout_rate=dropout_rate, use_glu=use_glu)
                for _ in range(num_blocks)
            ]
        )

        # The final linear projection remains (model_dim -> label_dim)
        self.linear = nn.Linear(model_dim, label_dim)

    def mask_grads(self):
        """
        No-op; included for interface consistency with other models.
        """
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of StackedS4.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len)
                containing integer token IDs, or (batch_size, seq_len, 2)
                token-ID pairs when ``second_embedding`` is True.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, label_dim).
                If a single-vector output is desired (e.g. for classification),
                additional pooling or indexing may be required
                before the final linear layer or after its output.
        """
        # Embedding: (batch_size, seq_len, model_dim)
        if not self.second_embedding:
            x = self.embedding(x)
        else:
            x = torch.cat(
                [self.embedding(x[:, :, 0]), self.embedding2(x[:, :, 1])], dim=-1
            )

        # Pass through each S4Block
        for block in self.blocks:
            x = block(x)

        # Final projection: (batch_size, seq_len, label_dim)
        return self.linear(x)

    def step(self, x: torch.Tensor) -> torch.Tensor:
        """
        Single-step (recurrent) inference over one token per batch element.

        NOTE(review): S4Block defines no ``step`` method in this file, so this
        path raises AttributeError as written; it appears to expect a
        recurrent S4 block implementation. Confirm before relying on it.
        """
        # Embedding for the current step
        if self.second_embedding:
            x = torch.cat(
                [self.embedding(x[:, 0].long()), self.embedding2(x[:, 1].long())],
                dim=-1,
            )
        else:
            x = self.embedding(x)

        for block in self.blocks:
            x = block.step(x)

        # Final projection for the step
        return self.linear(x)
0 commit comments