Skip to content

Commit b7d721e

Browse files
Cache handling for Mamba2
Signed-off-by: Nune <ntadevosyan@nvidia.com>
1 parent de80996 commit b7d721e

File tree

2 files changed: +26 −16 lines

mamba_ssm/modules/mamba2.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
self.use_mem_eff_path = use_mem_eff_path
9393
self.layer_idx = layer_idx
9494
self.ssm_state = None
95+
self.conv_state = None
9596

9697
# Order: [z, x, B, C, dt]
9798
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
@@ -161,7 +162,6 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
161162
Returns: same shape as u
162163
"""
163164
seqlen_og = seqlen
164-
#import pdb; pdb.set_trace()
165165
cache_device = self.in_proj.weight.device
166166
cache_dtype = self.in_proj.weight.dtype
167167
if seqlen is None:
@@ -170,18 +170,11 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
170170
batch_seqlen, dim = u.shape
171171
batch = batch_seqlen // seqlen
172172

173-
should_cache_states = inference_params is None and cu_seqlens is None
173+
should_cache_states = inference_params is not None
174174
cached_state = self._maybe_get_cached_state(batch, cache_device, cache_dtype) if should_cache_states else None
175175

176176
conv_state, ssm_state = None, None
177-
if inference_params is not None:
178-
inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
179-
conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
180-
if inference_params.seqlen_offset > 0:
181-
# The states are updated inplace
182-
out, _, _ = self.step(u, conv_state, ssm_state)
183-
return out
184-
177+
185178
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
186179
if seqlen_og is not None:
187180
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
@@ -227,34 +220,37 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
227220
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
228221
dim=-1
229222
)
223+
224+
assert self.activation in ["silu", "swish"]
225+
conv_state = self._prepare_conv_state(xBC.transpose(1, 2), batch)
230226
if conv_state is not None:
231227
if cu_seqlens is None:
232228
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
233229
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
234230
xBC_t = rearrange(xBC, "b l d -> b d l")
235-
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
231+
self.conv_state = F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)) # Update state (B D W)
236232
else:
237233
assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
238234
assert batch == 1, "varlen inference only supports batch dimension 1"
239235
conv_varlen_states = causal_conv1d_varlen_states(
240236
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
241237
)
242238
conv_state.copy_(conv_varlen_states)
243-
assert self.activation in ["silu", "swish"]
244239
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
245240
assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
246241
xBC = self.act(
247242
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]
248243
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
249244
else:
250-
xBC = causal_conv1d_fn(
245+
xBC = causal_conv1d_update(
251246
xBC.transpose(1, 2),
252-
rearrange(self.conv1d.weight, "d 1 w -> d w"),
247+
conv_state=conv_state,
248+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
253249
bias=self.conv1d.bias,
254250
activation=self.activation,
255-
seq_idx=seq_idx,
256251
).transpose(1, 2)
257252
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
253+
258254
return_varlen_states = cu_seqlens is not None and inference_params is not None
259255
initial_states = ssm_state if ssm_state is not None else cached_state
260256
return_final_states = (ssm_state is not None) or should_cache_states
@@ -409,6 +405,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
409405

410406
def reset_cache(self):
    """Discard the cached SSM and convolution states.

    Both caches are lazily re-created on the next forward pass, so
    clearing them simply forces a fresh zero-initialized state.
    """
    self.ssm_state = self.conv_state = None
412409

413410
def _maybe_get_cached_state(self, batch_size, device, dtype):
414411
if self.ssm_state is None:
@@ -427,3 +424,17 @@ def _update_cache_state(self, new_state):
427424
if cache_state.dtype != target_dtype:
428425
cache_state = cache_state.to(dtype=target_dtype)
429426
self.ssm_state = cache_state.contiguous()
427+
428+
def _prepare_conv_state(self, x, batch_size):
429+
state_len = self.d_conv - 1
430+
if state_len <= 0:
431+
self.conv_state = None
432+
return None
433+
if (
434+
self.conv_state is None
435+
or self.conv_state.shape[0] != batch_size
436+
or self.conv_state.shape[1] != x.shape[1]
437+
):
438+
self.conv_state = x.new_zeros(batch_size, x.shape[1], state_len)
439+
return self.conv_state
440+

mamba_ssm/modules/mamba2_simple.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def forward(self, u, seq_idx=None):
127127
Returns: same shape as u
128128
"""
129129
batch, seqlen, dim = u.shape
130-
import pdb; pdb.set_trace()
131130
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
132131
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
133132
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None

0 commit comments

Comments
 (0)