66from mamba_ssm .ops .selective_scan_interface import selective_scan_fn
77
88
class S4Recurrence(nn.Module):
    """Real (S4D-style) diagonal state-space recurrence.

    Drop-in sibling of MambaRecurrence that uses the same
    ``selective_scan_fn`` kernel, but with input-independent
    (non-selective) dt/A/B/C/D parameters.

    Args:
        d_model: feature dimension D of inputs shaped (B, L, D).
        d_state: per-channel state size N.
        dt_rank: kept for API symmetry with MambaRecurrence; unused here.
        dt_min: lower bound of the log-uniform step-size initialization.
        dt_max: upper bound of the log-uniform step-size initialization.
        dt_init_floor: hard floor applied to the sampled dt.
        device: device for the parameters; defaults to CUDA when available.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        dt_rank="auto",  # kept for API symmetry; unused here
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        device=None,
    ):
        super().__init__()
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # Per-channel step size sampled log-uniformly in [dt_min, dt_max],
        # floored, then stored in log-space so exp() keeps it positive.
        dt = torch.exp(
            torch.rand(self.d_model, device=self.device)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        self.log_dt = nn.Parameter(torch.log(dt))

        # S4D-real initialization: A[d, n] = n for n = 1..d_state, stored as
        # log so the sign/positivity is preserved under optimization.
        # (torch .repeat is the exact equivalent of einops repeat "n -> d n"
        # and keeps this block free of the einops dependency.)
        A = (
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=self.device)
            .repeat(self.d_model, 1)
            .contiguous()
        )
        self.A_log = nn.Parameter(torch.log(A))

        self.B = nn.Parameter(torch.empty(self.d_model, self.d_state, device=self.device))
        self.C = nn.Parameter(torch.empty(self.d_model, self.d_state, device=self.device))
        nn.init.xavier_normal_(self.B)
        nn.init.xavier_normal_(self.C)
        # Skip connection weight, one scalar per channel.
        self.D = nn.Parameter(torch.ones(self.d_model, device=self.device))

    def forward(self, hidden_states):
        """Run the recurrence over hidden_states of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D).
        """
        # (B, L, D) -> (B, D, L); contiguous for the scan kernel.
        x = hidden_states.transpose(1, 2).contiguous()

        # Negative-real poles guarantee a stable (decaying) recurrence.
        A = -torch.exp(self.A_log.float())

        # Broadcast the per-channel dt to (B, D, L). expand() produces a
        # zero-stride view, so materialize it explicitly with .contiguous()
        # (mirrors the .contiguous() on dt in MambaRecurrence.forward) rather
        # than relying on the kernel interface to copy it silently.
        dt = (
            self.log_dt.exp()[None, :, None]
            .expand(x.shape[0], self.d_model, x.shape[2])
            .contiguous()
        )

        y = selective_scan_fn(
            x,
            dt,
            A,
            self.B.float(),
            self.C.float(),
            self.D.float(),
            z=None,
            # dt is already positive via exp(log_dt); no softplus needed.
            delta_softplus=False,
            return_last_state=False,
        )
        return y.transpose(1, 2)
79+
80+
981class MambaRecurrence (nn .Module ):
1082 """
1183 Implements the Mamba recurrence layer for sequence modeling.
@@ -113,13 +185,13 @@ def forward(self, hidden_states):
113185 """
114186 batch , seqlen , dim = hidden_states .shape
115187 A = - torch .exp (self .A_log .float ()) # (d_inner, d_state)
116- x = rearrange (hidden_states , "b l d -> b d l" )
188+ x = rearrange (hidden_states , "b l d -> b d l" ). contiguous ()
117189 x_dbl = self .x_proj (rearrange (x , "b d l -> (b l) d" )) # (bl d)
118190 dt , B , C = torch .split (
119191 x_dbl , [self .dt_rank , self .d_state , self .d_state ], dim = - 1
120192 )
121193 dt = self .dt_proj .weight @ dt .t ()
122- dt = rearrange (dt , "d (b l) -> b d l" , l = seqlen )
194+ dt = rearrange (dt , "d (b l) -> b d l" , l = seqlen ). contiguous ()
123195 B = rearrange (B , "(b l) dstate -> b dstate l" , l = seqlen ).contiguous ()
124196 C = rearrange (C , "(b l) dstate -> b dstate l" , l = seqlen ).contiguous ()
125197 y = selective_scan_fn (
@@ -137,9 +209,10 @@ def forward(self, hidden_states):
137209 return rearrange (y , "b d l -> b l d" )
138210
139211
140- class S6 (nn .Module ):
212+ class SSM (nn .Module ):
141213 def __init__ (
142214 self ,
215+ recurrence_type : str ,
143216 num_blocks : int ,
144217 data_dim : int ,
145218 model_dim : int ,
@@ -162,9 +235,16 @@ def __init__(
162235 if second_embedding :
163236 self .embedding2 = nn .Embedding (data_dim , emb_dim )
164237
238+ if recurrence_type == "S6" :
239+ recurrence_cls = MambaRecurrence
240+ elif recurrence_type == "S4" :
241+ recurrence_cls = S4Recurrence
242+ else :
243+ raise ValueError (f"Unknown recurrence type: { recurrence_type } " )
244+
165245 self .blocks = nn .ModuleList (
166246 [
167- MambaRecurrence (model_dim , d_state = d_state , dt_rank = dt_rank )
247+ recurrence_cls (model_dim , d_state = d_state , dt_rank = dt_rank )
168248 for _ in range (num_blocks )
169249 ]
170250 )
0 commit comments