
Commit 6cf5fc0

lingyiyang authored and Benjamin-Walker committed
Add S6 model aka MambaRecurrence
1 parent 420dbec commit 6cf5fc0

2 files changed: +226 −0 lines changed

models/s6.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
import math

import torch
import torch.nn as nn
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn


class MambaRecurrence(nn.Module):
    """
    Implements the Mamba recurrence layer for sequence modeling.

    Args:
        d_model (int): Dimension of the model (number of features).
        d_state (int): Dimension of the state space. Defaults to 16.
        dt_rank (int or str): Rank for the time-step parameterization. Defaults to
            "auto".
        dt_min (float): Minimum value for time-steps. Defaults to 0.001.
        dt_max (float): Maximum value for time-steps. Defaults to 0.1.
        dt_init (str): Initialization method for dt. Options are "constant" or "random".
            Defaults to "random".
        dt_scale (float): Scale factor for dt initialization. Defaults to 1.0.
        dt_init_floor (float): Floor value for initializing time-steps. Defaults to
            1e-4.
        device (torch.device, optional): Device to run the computations on. If None,
            uses CUDA if available.

    Forward Args:
        hidden_states (Tensor): Input tensor of shape (batch_size, sequence_length,
            d_model).

    Returns:
        Tensor: Output tensor of the same shape as input.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        device=None,
    ):
        super().__init__()

        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.x_proj = nn.Linear(
            self.d_model,
            self.dt_rank + self.d_state * 2,
            bias=False,
            device=self.device,
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_model, bias=True, device=self.device
        )

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_model, device=self.device)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this
        # one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=self.device),
            "n -> d n",
            d=self.d_model,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)

        # D "skip" parameter
        self.D = nn.Parameter(
            torch.ones(self.d_model, device=self.device)
        )  # Keep in fp32
        self.D._no_weight_decay = True

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        x = rearrange(hidden_states, "b l d -> b d l")
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        dt = self.dt_proj.weight @ dt.t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=None,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=False,
        )
        return rearrange(y, "b d l -> b l d")


class S6(nn.Module):
    def __init__(
        self,
        num_blocks: int,
        data_dim: int,
        model_dim: int,
        label_dim: int,
        dropout_rate: float = 0.1,
        second_embedding: bool = False,
        use_glu: bool = False,
        d_state: int = 16,
        dt_rank: str | int = "auto",
    ):
        """
        d_state: state dimension of the SSM (default 16).
        dt_rank: rank for dt parameterization ("auto" uses ceil(model_dim/16)).
        """
        super().__init__()
        self.second_embedding = second_embedding

        emb_dim = model_dim // 2 if second_embedding else model_dim
        self.embedding = nn.Embedding(data_dim, emb_dim)
        if second_embedding:
            self.embedding2 = nn.Embedding(data_dim, emb_dim)

        self.blocks = nn.ModuleList(
            [
                MambaRecurrence(model_dim, d_state=d_state, dt_rank=dt_rank)
                for _ in range(num_blocks)
            ]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(model_dim) for _ in range(num_blocks)])

        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(model_dim, label_dim)

        self.use_glu = use_glu
        if use_glu:
            self.glu_projs = nn.ModuleList(
                [nn.Linear(model_dim, 2 * model_dim) for _ in range(num_blocks)]
            )
        else:
            self.glu_projs = nn.ModuleList(
                [nn.Linear(model_dim, model_dim) for _ in range(num_blocks)]
            )
        self.act = nn.GLU()

    def mask_grads(self):
        pass

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        if not self.second_embedding:
            # x: (B, L)
            return self.embedding(x)  # -> (B, L, model_dim)
        else:
            # x: (B, L, 2)
            return torch.cat(
                [self.embedding(x[:, :, 0]), self.embedding2(x[:, :, 1])], dim=-1
            )  # -> (B, L, model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._embed(x)  # (B, L, D=model_dim)
        for i, (blk, ln) in enumerate(zip(self.blocks, self.norms)):
            residual = x
            x = blk(x)
            h = self.glu_projs[i](x)
            h = self.act(h) if self.use_glu else torch.tanh(h)
            x = residual + x + h
            x = self.dropout(ln(x))
        return self.linear(x)  # (B, L, label_dim)
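
For reference, a minimal smoke test of the new module could look like the sketch below (not part of this commit). It assumes a CUDA device, since selective_scan_fn relies on the mamba_ssm CUDA kernels, and the batch size, sequence length, vocabulary size, and label count are illustrative values only.

import torch

from models.s6 import MambaRecurrence, S6

device = torch.device("cuda")  # assumes CUDA is available

# Single recurrence layer: (B, L, d_model) in, (B, L, d_model) out.
layer = MambaRecurrence(d_model=64, d_state=16, dt_rank="auto", device=device)
h = torch.randn(2, 128, 64, device=device)
print(layer(h).shape)  # torch.Size([2, 128, 64])

# Full stack: integer token ids in, per-position logits out.
model = S6(
    num_blocks=2,
    data_dim=32,  # vocabulary size for nn.Embedding (illustrative)
    model_dim=64,
    label_dim=10,
    dropout_rate=0.1,
    d_state=16,
    dt_rank="auto",
).to(device)
tokens = torch.randint(0, 32, (2, 128), device=device)
print(model(tokens).shape)  # torch.Size([2, 128, 10])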

train.py

Lines changed: 17 additions & 0 deletions
@@ -298,6 +298,8 @@ def run_experiment(config):
     slstm_at = config.get("slstm_at", [1])
     vf_A_norm_lambda = config.get("vf_A_norm_lambda", 0.001)
     rank = config.get("rank", 0)
+    d_state = config.get("d_state", 16)
+    dt_rank = config.get("dt_rank", "auto")

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -399,6 +401,21 @@ def train_dataloader_multilength():
             use_glu=use_glu,
             second_embedding=second_embedding,
         )
+    elif model_name == "S6":
+        from models.s6 import S6
+
+        model = S6(
+            num_blocks=num_blocks,
+            model_dim=model_dim,
+            data_dim=data_dim,
+            label_dim=label_dim,
+            dropout_rate=dropout_rate,
+            use_glu=use_glu,
+            second_embedding=second_embedding,
+            d_state=d_state,
+            dt_rank=dt_rank,
+        )
+
     elif model_name in ["deltanet", "gateddeltanet", "rwkv7", "rwkv6", "deltaproduct"]:
         from models.fla import StackedBlock

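Since both new settings are read with config.get and have defaults, existing configs keep working unchanged. A hypothetical config fragment selecting the new model could look like the following; only "d_state" and "dt_rank" are read via config.get in this diff, while "model_name" and the remaining keys are placeholders for whatever the experiment already defines.

config = {
    "model_name": "S6",
    "num_blocks": 2,            # placeholder value
    "model_dim": 64,            # placeholder value
    "dropout_rate": 0.1,        # placeholder value
    "use_glu": False,           # placeholder value
    "second_embedding": False,  # placeholder value
    "d_state": 16,              # new key; defaults to 16 if omitted
    "dt_rank": "auto",          # new key; defaults to "auto" if omitted
}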