@@ -91,6 +91,7 @@ def __init__(
         self.chunk_size = chunk_size
         self.use_mem_eff_path = use_mem_eff_path
         self.layer_idx = layer_idx
+        self.ssm_state = None

         # Order: [z, x, B, C, dt]
         d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
@@ -159,14 +160,19 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
             (in case batch is small).
         Returns: same shape as u
         """
-        import pdb; pdb.set_trace()
         seqlen_og = seqlen
+        cache_device = self.in_proj.weight.device
+        cache_dtype = self.in_proj.weight.dtype
         if seqlen is None:
             batch, seqlen, dim = u.shape
         else:
             batch_seqlen, dim = u.shape
             batch = batch_seqlen // seqlen

+        # Cache final states only on the plain full-sequence path (no inference cache, no varlen).
+        should_cache_states = inference_params is None and cu_seqlens is None
+        cached_state = self._maybe_get_cached_state(batch, cache_device, cache_dtype) if should_cache_states else None
+
         conv_state, ssm_state = None, None
         if inference_params is not None:
             inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
@@ -183,6 +189,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
         A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
         dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
         if self.use_mem_eff_path and inference_params is None:
+            return_final_states = should_cache_states
             out = mamba_split_conv1d_scan_combined(
                 zxbcdt,
                 rearrange(self.conv1d.weight, "d 1 w -> d w"),
@@ -200,8 +207,14 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                 headdim=None if self.D_has_hdim else self.headdim,
                 ngroups=self.ngroups,
                 norm_before_gate=self.norm_before_gate,
+                initial_states=cached_state,
+                return_final_states=return_final_states,
                 **dt_limit_kwargs,
             )
+            if return_final_states:
+                out, final_states = out
+                if should_cache_states and seqlen > 0:
+                    self._update_cache_state(final_states)
             if seqlen_og is not None:
                 out = rearrange(out, "b l d -> (b l) d")
             if self.process_group is not None:
@@ -242,6 +255,9 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                     seq_idx=seq_idx,
                 ).transpose(1, 2)
             x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
+            return_varlen_states = cu_seqlens is not None and inference_params is not None
+            initial_states = ssm_state if ssm_state is not None else cached_state
+            return_final_states = (ssm_state is not None) or should_cache_states
             y = mamba_chunk_scan_combined(
                 rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                 dt,
@@ -256,16 +272,24 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                 seq_idx=seq_idx,
                 cu_seqlens=cu_seqlens,
                 **dt_limit_kwargs,
-                return_final_states=ssm_state is not None,
-                return_varlen_states=cu_seqlens is not None and inference_params is not None,
+                return_final_states=return_final_states,
+                return_varlen_states=return_varlen_states,
+                initial_states=initial_states,
             )
-            if ssm_state is not None:
-                y, last_state, *rest = y
-                if cu_seqlens is None:
-                    ssm_state.copy_(last_state)
-                else:
+            if return_final_states:
+                if return_varlen_states:
+                    y, last_state, *rest = y
                     varlen_states = rest[0]
-                    ssm_state.copy_(varlen_states)
+                else:
+                    y, last_state = y
+                    varlen_states = None
+                if ssm_state is not None:
+                    if cu_seqlens is None:
+                        ssm_state.copy_(last_state)
+                    else:
+                        ssm_state.copy_(varlen_states)
+                if should_cache_states and cu_seqlens is None and seqlen > 0:
+                    self._update_cache_state(last_state)
             y = rearrange(y, "b l h p -> b l (h p)")
             if self.rmsnorm:
                 y = self.norm(y, z)
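
The correctness argument behind passing initial_states/return_final_states through the scan is that an SSM is a linear recurrence: feeding the final state of one chunk back in as the initial state of the next reproduces the full-sequence scan. A toy check with a diagonal recurrence (diag_scan is an illustrative stand-in, not the real SSD kernel from mamba_ssm):

    import torch

    def diag_scan(x, a, h0=None):
        # Runs h_t = a * h_{t-1} + x_t over the sequence dimension,
        # returning all states and the final one.
        batch, seqlen, dim = x.shape
        h = x.new_zeros(batch, dim) if h0 is None else h0
        ys = []
        for t in range(seqlen):
            h = a * h + x[:, t]
            ys.append(h)
        return torch.stack(ys, dim=1), h

    torch.manual_seed(0)
    x = torch.randn(2, 8, 4)
    a = torch.rand(4)

    # One scan over the whole sequence...
    y_full, _ = diag_scan(x, a)

    # ...matches two chunked scans when the final state of chunk 1 is
    # passed as the initial state of chunk 2, which is the contract the
    # cached_state / last_state plumbing above relies on.
    y1, state = diag_scan(x[:, :5], a)
    y2, _ = diag_scan(x[:, 5:], a, h0=state)
    assert torch.allclose(torch.cat([y1, y2], dim=1), y_full, atol=1e-6)
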
@@ -382,3 +406,27 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
             conv_state.zero_()
             ssm_state.zero_()
         return conv_state, ssm_state
+
+    def reset_cache(self):
+        """Drop any cached SSM state, e.g. between unrelated sequences."""
+        self.ssm_state = None
+
+    def _maybe_get_cached_state(self, batch_size, device, dtype):
+        # Invalidate the cache when batch size, device, or dtype change.
+        if self.ssm_state is None:
+            return None
+        if self.ssm_state.shape[0] != batch_size or self.ssm_state.device != device or self.ssm_state.dtype != dtype:
+            self.ssm_state = None
+            return None
+        return self.ssm_state
+
+    def _update_cache_state(self, new_state):
+        # Detach and cast the final state to the projection dtype before caching.
+        if new_state is None:
+            self.ssm_state = None
+            return
+        cache_state = new_state.detach()
+        target_dtype = self.in_proj.weight.dtype
+        if cache_state.dtype != target_dtype:
+            cache_state = cache_state.to(dtype=target_dtype)
+        self.ssm_state = cache_state.contiguous()
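
A minimal usage sketch of the resulting behavior, assuming this patch is applied to mamba_ssm's Mamba2 and the Triton kernels are available on a CUDA device (shapes and hyperparameters are illustrative):

    import torch
    from mamba_ssm.modules.mamba2 import Mamba2

    layer = Mamba2(d_model=256).to("cuda")
    x = torch.randn(2, 1024, 256, device="cuda")

    # Without inference_params or cu_seqlens, each forward caches its final
    # SSM state, and the next forward resumes from it via initial_states.
    y1 = layer(x[:, :512])
    y2 = layer(x[:, 512:])

    # Starting a new, unrelated sequence: drop the carried state explicitly.
    layer.reset_cache()
    y = layer(x)

    # Changing batch size, device, or dtype between calls also invalidates
    # the cache automatically via _maybe_get_cached_state.

Note that as written the diff caches only the SSM state, not the causal-conv1d window, so chunked outputs can differ slightly from a single full-sequence forward near chunk boundaries.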