Skip to content

Commit 6da252d

Browse files
committed
Update mamba.py
Removed mamba2 support due to instability
1 parent 173a7c7 commit 6da252d

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

bayesflow/wrappers/mamba/mamba.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import keras
44
# from keras.saving import register_keras_serializable as serializable
55
try:
6-
from mamba_ssm import Mamba, Mamba2
6+
from mamba_ssm import Mamba
77
except ImportError:
88
print("Mamba Wrapper is not available")
99

@@ -14,7 +14,7 @@
1414
class MambaSSM(SummaryNetwork):
1515
def __init__(
1616
self,
17-
ssm_dim: int,
17+
feature_dim: int,
1818
summary_dim: int = 8,
1919
mamba_blocks: int = 2,
2020
state_dim: int = 16,
@@ -24,19 +24,19 @@ def __init__(
2424
dt_max: float = 0.1,
2525
pooling: bool = True,
2626
dropout: int | float | None = 0.5,
27-
mamba_version: int = 2,
2827
device: str = "cuda",
29-
d_ssm: int = 1,
3028
**kwargs
3129
):
3230
"""
3331
A time-series summarization network using Mamba-based State Space Models (SSM).
3432
This model processes sequential input data using the Mamba SSM layer, followed by
3533
optional pooling, dropout, and a dense layer for extracting summary statistics.
34+
35+
Mamba2 support is currently unavailable due to stability issues.
3636
3737
Parameters
3838
----------
39-
ssm_dim : int
39+
feature_dim : int
4040
The dimensionality of the Mamba SSM model.
4141
summary_dim : int, optional
4242
The output dimensionality of the summary statistics layer (default is 8).
@@ -56,8 +56,6 @@ def __init__(
5656
Whether to apply global average pooling (default is True).
5757
dropout : int, float, or None, optional
5858
Dropout rate applied before the summary layer (default is 0.5).
59-
mamba_version : int, optional
60-
The version of Mamba to apply (default is 2).
6159
device : str, optional
6260
The computing device. Currently, only "cuda" is supported (default is "cuda").
6361
**kwargs : dict
@@ -68,15 +66,7 @@ def __init__(
6866
if device != "cuda":
6967
raise NotImplementedError("MambaSSM currently only supports cuda")
7068

71-
if mamba_version == 1:
72-
self.mamba_blocks = [Mamba(d_model=ssm_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max).to(device) for _ in range(mamba_blocks)]
73-
elif mamba_version == 2:
74-
self.mamba_blocks = [Mamba2(d_model=ssm_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max, d_ssm=d_ssm).to(device) for _ in range(mamba_blocks)]
75-
else:
76-
raise NotImplementedError("Mamba version must be 1 or 2")
77-
78-
79-
69+
self.mamba_blocks = [Mamba(d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max).to(device) for _ in range(mamba_blocks)]
8070

8171
self.layernorm = keras.layers.LayerNormalization(axis=-1)
8272

0 commit comments

Comments
 (0)