|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from einops import rearrange, repeat |
| 6 | +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn |
| 7 | + |
| 8 | + |
class MambaRecurrence(nn.Module):
    """Selective state-space (Mamba) recurrence over a sequence.

    Each time step is projected to input-dependent SSM parameters
    (dt, B, C) and the sequence is processed by the fused selective scan
    kernel from ``mamba_ssm``.

    Args:
        d_model (int): Number of input/output features per time step.
        d_state (int): State dimension of the SSM. Defaults to 16.
        dt_rank (int or str): Rank of the dt bottleneck; ``"auto"`` means
            ``ceil(d_model / 16)``. Defaults to "auto".
        dt_min (float): Lower bound for the sampled initial time steps.
            Defaults to 0.001.
        dt_max (float): Upper bound for the sampled initial time steps.
            Defaults to 0.1.
        dt_init (str): Weight init for the dt projection, "constant" or
            "random". Defaults to "random".
        dt_scale (float): Scale applied to the dt weight-init std.
            Defaults to 1.0.
        dt_init_floor (float): Clamp floor for sampled initial time steps.
            Defaults to 1e-4.
        device (torch.device, optional): Target device; CUDA is used when
            available if omitted.

    Forward Args:
        hidden_states (Tensor): Input of shape (batch_size,
            sequence_length, d_model).

    Returns:
        Tensor: Output tensor of the same shape as the input.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        device=None,
    ):
        super().__init__()

        if device is not None:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.d_model = d_model
        self.d_state = d_state
        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # One shared projection emits dt (dt_rank), B and C (d_state each).
        self.x_proj = nn.Linear(
            self.d_model,
            self.dt_rank + self.d_state * 2,
            bias=False,
            device=self.device,
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_model, bias=True, device=self.device
        )

        # Special dt-projection weight init that preserves variance.
        std = dt_scale * self.dt_rank**-0.5
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -std, std)
        else:
            raise NotImplementedError

        # Pick the bias so that F.softplus(bias) falls in [dt_min, dt_max]:
        # sample dt log-uniformly, then apply the inverse of softplus.
        log_lo = math.log(dt_min)
        log_hi = math.log(dt_max)
        dt = torch.exp(
            torch.rand(self.d_model, device=self.device) * (log_hi - log_lo) + log_lo
        ).clamp(min=dt_init_floor)
        # softplus^-1(y) = y + log(1 - exp(-y));
        # see https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Blanket re-init passes set Linear biases to zero; shield this one.
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization: A[d, n] = n for n = 1..d_state.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=self.device),
            "n -> d n",
            d=self.d_model,
        ).contiguous()
        self.A_log = nn.Parameter(torch.log(A))  # stored in log space, fp32

        # D "skip" parameter, kept in fp32.
        self.D = nn.Parameter(torch.ones(self.d_model, device=self.device))
        self.D._no_weight_decay = True

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        x = rearrange(hidden_states, "b l d -> b d l")
        # All per-step SSM parameters come from one projection over (b*l, d).
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        # Bias is deliberately left out here; the kernel adds it via delta_bias.
        dt = rearrange(self.dt_proj.weight @ dt.t(), "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=None,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=False,
        )
        return rearrange(y, "b d l -> b l d")
| 138 | + |
| 139 | + |
| 140 | +class S6(nn.Module): |
| 141 | + def __init__( |
| 142 | + self, |
| 143 | + num_blocks: int, |
| 144 | + data_dim: int, |
| 145 | + model_dim: int, |
| 146 | + label_dim: int, |
| 147 | + dropout_rate: float = 0.1, |
| 148 | + second_embedding: bool = False, |
| 149 | + use_glu: bool = False, |
| 150 | + d_state: int = 16, |
| 151 | + dt_rank: str | int = "auto", |
| 152 | + ): |
| 153 | + """ |
| 154 | + d_state: state dimension of the SSM (default 16). |
| 155 | + dt_rank: rank for dt parameterization ("auto" uses ceil(model_dim/16)). |
| 156 | + """ |
| 157 | + super().__init__() |
| 158 | + self.second_embedding = second_embedding |
| 159 | + |
| 160 | + emb_dim = model_dim // 2 if second_embedding else model_dim |
| 161 | + self.embedding = nn.Embedding(data_dim, emb_dim) |
| 162 | + if second_embedding: |
| 163 | + self.embedding2 = nn.Embedding(data_dim, emb_dim) |
| 164 | + |
| 165 | + self.blocks = nn.ModuleList( |
| 166 | + [ |
| 167 | + MambaRecurrence(model_dim, d_state=d_state, dt_rank=dt_rank) |
| 168 | + for _ in range(num_blocks) |
| 169 | + ] |
| 170 | + ) |
| 171 | + self.norms = nn.ModuleList([nn.LayerNorm(model_dim) for _ in range(num_blocks)]) |
| 172 | + |
| 173 | + self.dropout = nn.Dropout(dropout_rate) |
| 174 | + self.linear = nn.Linear(model_dim, label_dim) |
| 175 | + |
| 176 | + self.use_glu = use_glu |
| 177 | + if use_glu: |
| 178 | + self.glu_projs = nn.ModuleList( |
| 179 | + [nn.Linear(model_dim, 2 * model_dim) for _ in range(num_blocks)] |
| 180 | + ) |
| 181 | + else: |
| 182 | + self.glu_projs = nn.ModuleList( |
| 183 | + [nn.Linear(model_dim, model_dim) for _ in range(num_blocks)] |
| 184 | + ) |
| 185 | + self.act = nn.GLU() |
| 186 | + |
| 187 | + def mask_grads(self): |
| 188 | + pass |
| 189 | + |
| 190 | + def _embed(self, x: torch.Tensor) -> torch.Tensor: |
| 191 | + if not self.second_embedding: |
| 192 | + # x: (B, L) |
| 193 | + return self.embedding(x) # -> (B, L, model_dim) |
| 194 | + else: |
| 195 | + # x: (B, L, 2) |
| 196 | + return torch.cat( |
| 197 | + [self.embedding(x[:, :, 0]), self.embedding2(x[:, :, 1])], dim=-1 |
| 198 | + ) # -> (B, L, model_dim) |
| 199 | + |
| 200 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 201 | + x = self._embed(x) # (B, L, D=model_dim) |
| 202 | + for i, (blk, ln) in enumerate(zip(self.blocks, self.norms)): |
| 203 | + residual = x |
| 204 | + x = blk(x) |
| 205 | + h = self.glu_projs[i](x) |
| 206 | + h = self.act(h) if self.use_glu else torch.tanh(h) |
| 207 | + x = residual + x + h |
| 208 | + x = self.dropout(ln(x)) |
| 209 | + return self.linear(x) # (B, L, label_dim) |
0 commit comments