
Commit cf97f6c

Fix mamba regression (#39728)

* fix mamba regression
* fix compile test

1 parent 66984ed · commit cf97f6c

File tree

3 files changed: +16 −5 lines

src/transformers/models/falcon_mamba/configuration_falcon_mamba.py

Lines changed: 6 additions & 1 deletion

@@ -141,7 +141,12 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.conv_kernel = conv_kernel
         self.expand = expand
-        self.intermediate_size = int(expand * self.hidden_size)
+        # This is needed since mamba overrides the intermediate_size attribute
+        self.intermediate_size = (
+            int(expand * self.hidden_size)
+            if kwargs.get("intermediate_size") is None
+            else kwargs.get("intermediate_size")
+        )
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
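For context on the regression this guards against: the parent Mamba config recomputes intermediate_size as expand * hidden_size, which silently discarded a value supplied explicitly (for example from a checkpoint's config.json). A minimal sketch of the behavior after this change, assuming a transformers build that includes this commit; the specific sizes are illustrative:

from transformers import FalconMambaConfig

# An explicit value is preserved instead of being recomputed.
config = FalconMambaConfig(hidden_size=768, expand=2, intermediate_size=2048)
assert config.intermediate_size == 2048

# Without the keyword, the derived default still applies.
default_config = FalconMambaConfig(hidden_size=768, expand=2)
assert default_config.intermediate_size == 2 * 768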

src/transformers/models/falcon_mamba/modular_falcon_mamba.py

Lines changed: 6 additions & 0 deletions

@@ -192,6 +192,12 @@ def __init__(
             **kwargs,
         )
         self.mixer_rms_eps = mixer_rms_eps
+        # This is needed since mamba overrides the intermediate_size attribute
+        self.intermediate_size = (
+            int(expand * self.hidden_size)
+            if kwargs.get("intermediate_size") is None
+            else kwargs.get("intermediate_size")
+        )
 
 
 class FalconMambaCache(MambaCache):

src/transformers/models/mamba/modeling_mamba.py

Lines changed: 4 additions & 4 deletions

@@ -39,6 +39,10 @@
 
 logger = logging.get_logger(__name__)
 
+if is_mambapy_available():
+    from mambapy.pscan import pscan
+else:
+    pscan = None
 
 if is_mamba_ssm_available():
     from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
@@ -330,10 +334,6 @@ def cuda_kernels_forward(
 
     # fmt: off
     def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
-        if is_mambapy_available():
-            from mambapy.pscan import pscan
-        else:
-            pscan = None
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated MLP's linear projection
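The second fix ("fix compile test") hoists the mambapy import out of slow_forward: a conditional import executed inside the forward pass is re-evaluated on every call and can interfere with torch.compile tracing, whereas a module-level binding is resolved once at import time. A minimal sketch of the same pattern with a hypothetical optional package fastscan (names are illustrative, not from the repo):

import torch

# Resolve the optional dependency once at module import, not per forward call.
try:
    from fastscan import pscan  # hypothetical accelerated parallel-scan kernel
except ImportError:
    pscan = None

def slow_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    if pscan is not None:
        return pscan(hidden_states)          # fast path, bound at module scope
    return torch.cumsum(hidden_states, -1)   # pure-PyTorch fallback

Because pscan is now a module-level constant, a compiler tracing slow_forward sees a stable branch condition instead of an import statement inside the hot path.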
