Skip to content

Commit d4d1827

Browse files
Added S4D
1 parent 254bb4b commit d4d1827

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

models/s6.py renamed to models/ssm.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,78 @@
66
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
77

88

9+
class S4Recurrence(nn.Module):
10+
"""
11+
Real S4D-style recurrence compatible with selective_scan_fn.
12+
"""
13+
14+
def __init__(
15+
self,
16+
d_model,
17+
d_state=16,
18+
dt_rank="auto", # kept for API symmetry; unused here
19+
dt_min=0.001,
20+
dt_max=0.1,
21+
dt_init_floor=1e-4,
22+
device=None,
23+
):
24+
super().__init__()
25+
if device is None:
26+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27+
else:
28+
self.device = device
29+
30+
self.d_model = d_model
31+
self.d_state = d_state
32+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
33+
34+
dt = torch.exp(
35+
torch.rand(self.d_model, device=self.device)
36+
* (math.log(dt_max) - math.log(dt_min))
37+
+ math.log(dt_min)
38+
).clamp(min=dt_init_floor)
39+
self.log_dt = nn.Parameter(torch.log(dt))
40+
41+
# S4D real initialization
42+
A = repeat(
43+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=self.device),
44+
"n -> d n",
45+
d=self.d_model,
46+
).contiguous()
47+
self.A_log = nn.Parameter(torch.log(A))
48+
49+
self.B = nn.Parameter(
50+
torch.empty(self.d_model, self.d_state, device=self.device)
51+
)
52+
self.C = nn.Parameter(
53+
torch.empty(self.d_model, self.d_state, device=self.device)
54+
)
55+
nn.init.xavier_normal_(self.B)
56+
nn.init.xavier_normal_(self.C)
57+
self.D = nn.Parameter(torch.ones(self.d_model, device=self.device))
58+
59+
def forward(self, hidden_states):
60+
# x: (B, L, D) -> (B, D, L)
61+
x = rearrange(hidden_states, "b l d -> b d l").contiguous()
62+
63+
A = -torch.exp(self.A_log.float())
64+
dt = self.log_dt.exp()[None, :, None].expand(
65+
x.shape[0], self.d_model, x.shape[2]
66+
)
67+
y = selective_scan_fn(
68+
x,
69+
dt,
70+
A,
71+
self.B.float(),
72+
self.C.float(),
73+
self.D.float(),
74+
z=None,
75+
delta_softplus=False,
76+
return_last_state=False,
77+
)
78+
return rearrange(y, "b d l -> b l d")
79+
80+
981
class MambaRecurrence(nn.Module):
1082
"""
1183
Implements the Mamba recurrence layer for sequence modeling.
@@ -113,13 +185,13 @@ def forward(self, hidden_states):
113185
"""
114186
batch, seqlen, dim = hidden_states.shape
115187
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
116-
x = rearrange(hidden_states, "b l d -> b d l")
188+
x = rearrange(hidden_states, "b l d -> b d l").contiguous()
117189
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
118190
dt, B, C = torch.split(
119191
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
120192
)
121193
dt = self.dt_proj.weight @ dt.t()
122-
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
194+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen).contiguous()
123195
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
124196
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
125197
y = selective_scan_fn(
@@ -137,9 +209,10 @@ def forward(self, hidden_states):
137209
return rearrange(y, "b d l -> b l d")
138210

139211

140-
class S6(nn.Module):
212+
class SSM(nn.Module):
141213
def __init__(
142214
self,
215+
recurrence_type: str,
143216
num_blocks: int,
144217
data_dim: int,
145218
model_dim: int,
@@ -162,9 +235,16 @@ def __init__(
162235
if second_embedding:
163236
self.embedding2 = nn.Embedding(data_dim, emb_dim)
164237

238+
if recurrence_type == "S6":
239+
recurrence_cls = MambaRecurrence
240+
elif recurrence_type == "S4":
241+
recurrence_cls = S4Recurrence
242+
else:
243+
raise ValueError(f"Unknown recurrence type: {recurrence_type}")
244+
165245
self.blocks = nn.ModuleList(
166246
[
167-
MambaRecurrence(model_dim, d_state=d_state, dt_rank=dt_rank)
247+
recurrence_cls(model_dim, d_state=d_state, dt_rank=dt_rank)
168248
for _ in range(num_blocks)
169249
]
170250
)

train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,11 @@ def train_dataloader_multilength():
405405
use_glu=use_glu,
406406
second_embedding=second_embedding,
407407
)
408-
elif model_name == "S6":
409-
from models.s6 import S6
408+
elif model_name in ["S4", "S6"]:
409+
from models.ssm import SSM
410410

411-
model = S6(
411+
model = SSM(
412+
recurrence_type=model_name,
412413
num_blocks=num_blocks,
413414
model_dim=model_dim,
414415
data_dim=data_dim,

0 commit comments

Comments
 (0)