Commit de80996: Mamba2 state passing

Signed-off-by: Nune <ntadevosyan@nvidia.com>
1 parent b537ded


mamba_ssm/modules/mamba2.py (54 additions, 9 deletions)
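In short: this commit makes the `Mamba2` block carry its final SSM state across calls. When `forward` runs without `inference_params` and without `cu_seqlens`, the final state of the scan is stored in a new `self.ssm_state` attribute and fed back as `initial_states` on the next call, so long sequences can be processed chunk by chunk. Three new helpers manage the cache (`reset_cache`, `_maybe_get_cached_state`, `_update_cache_state`), and a stray `pdb.set_trace()` breakpoint in `forward` is disabled along the way.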
```diff
@@ -91,6 +91,7 @@ def __init__(
         self.chunk_size = chunk_size
         self.use_mem_eff_path = use_mem_eff_path
         self.layer_idx = layer_idx
+        self.ssm_state = None
 
         # Order: [z, x, B, C, dt]
         d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
```
```diff
@@ -159,14 +160,19 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
             (in case batch is small).
         Returns: same shape as u
         """
-        import pdb; pdb.set_trace()
         seqlen_og = seqlen
+        #import pdb; pdb.set_trace()
+        cache_device = self.in_proj.weight.device
+        cache_dtype = self.in_proj.weight.dtype
         if seqlen is None:
             batch, seqlen, dim = u.shape
         else:
             batch_seqlen, dim = u.shape
             batch = batch_seqlen // seqlen
 
+        should_cache_states = inference_params is None and cu_seqlens is None
+        cached_state = self._maybe_get_cached_state(batch, cache_device, cache_dtype) if should_cache_states else None
+
         conv_state, ssm_state = None, None
         if inference_params is not None:
             inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
```
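The two new locals gate everything that follows: the module-level cache is consulted only for plain full-sequence calls. A standalone restatement of the gating (the function below is a hypothetical illustration, not part of the commit):

```python
def should_cache_states(inference_params, cu_seqlens) -> bool:
    """Module-level SSM-state caching applies only to plain full-sequence calls.

    - inference_params is not None: step-by-step generation; state already
      lives in the per-request cache from _get_states_from_cache.
    - cu_seqlens is not None: packed variable-length batch; per-sequence
      final states are handled by the varlen path instead.
    """
    return inference_params is None and cu_seqlens is None
```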
```diff
@@ -183,6 +189,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
         A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
         dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
         if self.use_mem_eff_path and inference_params is None:
+            return_final_states = should_cache_states
             out = mamba_split_conv1d_scan_combined(
                 zxbcdt,
                 rearrange(self.conv1d.weight, "d 1 w -> d w"),
```
```diff
@@ -200,8 +207,14 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                 headdim=None if self.D_has_hdim else self.headdim,
                 ngroups=self.ngroups,
                 norm_before_gate=self.norm_before_gate,
+                initial_states=cached_state,
+                return_final_states=return_final_states,
                 **dt_limit_kwargs,
             )
+            if return_final_states:
+                out, final_states = out
+                if should_cache_states and seqlen > 0:
+                    self._update_cache_state(final_states)
             if seqlen_og is not None:
                 out = rearrange(out, "b l d -> (b l) d")
             if self.process_group is not None:
```
```diff
@@ -242,6 +255,9 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                 seq_idx=seq_idx,
             ).transpose(1, 2)
             x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
+            return_varlen_states = cu_seqlens is not None and inference_params is not None
+            initial_states = ssm_state if ssm_state is not None else cached_state
+            return_final_states = (ssm_state is not None) or should_cache_states
             y = mamba_chunk_scan_combined(
                 rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                 dt,
```
```diff
@@ -256,16 +272,24 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
                 seq_idx=seq_idx,
                 cu_seqlens=cu_seqlens,
                 **dt_limit_kwargs,
-                return_final_states=ssm_state is not None,
-                return_varlen_states=cu_seqlens is not None and inference_params is not None,
+                return_final_states=return_final_states,
+                return_varlen_states=return_varlen_states,
+                initial_states=initial_states,
             )
-            if ssm_state is not None:
-                y, last_state, *rest = y
-                if cu_seqlens is None:
-                    ssm_state.copy_(last_state)
-                else:
+            if return_final_states:
+                if return_varlen_states:
+                    y, last_state, *rest = y
                     varlen_states = rest[0]
-                ssm_state.copy_(varlen_states)
+                else:
+                    y, last_state = y
+                    varlen_states = None
+                if ssm_state is not None:
+                    if cu_seqlens is None:
+                        ssm_state.copy_(last_state)
+                    else:
+                        ssm_state.copy_(varlen_states)
+                if should_cache_states and cu_seqlens is None and seqlen > 0:
+                    self._update_cache_state(last_state)
             y = rearrange(y, "b l h p -> b l (h p)")
             if self.rmsnorm:
                 y = self.norm(y, z)
```
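The reworked unpacking assumes a fixed return order from `mamba_chunk_scan_combined`: the output first, the final states when `return_final_states=True`, then the packed per-sequence states when `return_varlen_states=True`. A hedged restatement as a standalone helper (the name and the shape comments are assumptions for illustration, not from this diff):

```python
def unpack_scan_outputs(result, return_final_states, return_varlen_states):
    """Mirror the unpacking above, assuming the kernel returns
    (y[, final_states][, varlen_states]) in that order.

    Assumed layouts: final_states is per-batch, roughly
    (batch, nheads, headdim, d_state); varlen_states holds one final
    state per packed sequence.
    """
    last_state, varlen_states = None, None
    if return_final_states and return_varlen_states:
        y, last_state, varlen_states = result
    elif return_final_states:
        y, last_state = result
    else:
        y = result
    return y, last_state, varlen_states
```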
```diff
@@ -382,3 +406,24 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
             conv_state.zero_()
             ssm_state.zero_()
         return conv_state, ssm_state
+
+    def reset_cache(self):
+        self.ssm_state = None
+
+    def _maybe_get_cached_state(self, batch_size, device, dtype):
+        if self.ssm_state is None:
+            return None
+        if self.ssm_state.shape[0] != batch_size or self.ssm_state.device != device or self.ssm_state.dtype != dtype:
+            self.ssm_state = None
+            return None
+        return self.ssm_state
+
+    def _update_cache_state(self, new_state):
+        if new_state is None:
+            self.ssm_state = None
+            return
+        cache_state = new_state.detach()
+        target_dtype = self.in_proj.weight.dtype
+        if cache_state.dtype != target_dtype:
+            cache_state = cache_state.to(dtype=target_dtype)
+        self.ssm_state = cache_state.contiguous()
```
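Taken together, the new helpers make the layer stateful across eligible calls: a long sequence can be fed in chunks, with the final SSM state of one chunk used as `initial_states` for the next. A minimal usage sketch (constructor arguments and sizes are illustrative; note the cache carries only the SSM state, not the short causal-conv window, so chunked outputs near a boundary are not bit-exact with a single full-length pass):

```python
import torch
from mamba_ssm.modules.mamba2 import Mamba2

layer = Mamba2(d_model=256, layer_idx=0).cuda()  # illustrative config
x = torch.randn(2, 1024, 256, device="cuda")

# Chunked processing: the final SSM state of chunk 1 is cached on the
# module and passed as initial_states for chunk 2.
y1 = layer(x[:, :512])
y2 = layer(x[:, 512:])

# Before starting an unrelated sequence, drop the cached state.
layer.reset_cache()
y_full = layer(x)

# The cache is also invalidated automatically when batch size, device,
# or dtype changes between calls (see _maybe_get_cached_state above).
```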
