Skip to content

Commit 173a7c7

Browse files
committed
Update mamba.py
Replaced the shared constructor variable (`mamba_gen`) with version-specific `Mamba`/`Mamba2` constructor calls, since the two versions now take different input arguments. Added a new hidden SSM dimension argument (`d_ssm`) for Mamba2.
1 parent 4f9d592 commit 173a7c7

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

bayesflow/wrappers/mamba/mamba.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def __init__(
2525
pooling: bool = True,
2626
dropout: int | float | None = 0.5,
2727
mamba_version: int = 2,
28-
device: str = "cuda",
28+
device: str = "cuda",
29+
d_ssm: int = 1,
2930
**kwargs
3031
):
3132
"""
@@ -68,13 +69,14 @@ def __init__(
6869
raise NotImplementedError("MambaSSM currently only supports cuda")
6970

7071
if mamba_version == 1:
71-
mamba_gen = Mamba
72+
self.mamba_blocks = [Mamba(d_model=ssm_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max).to(device) for _ in range(mamba_blocks)]
7273
elif mamba_version == 2:
73-
mamba_gen = Mamba2
74+
self.mamba_blocks = [Mamba2(d_model=ssm_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max, d_ssm=d_ssm).to(device) for _ in range(mamba_blocks)]
7475
else:
7576
raise NotImplementedError("Mamba version must be 1 or 2")
7677

77-
self.mamba_blocks = [mamba_gen(d_model=ssm_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max).to(device) for _ in range(mamba_blocks)]
78+
79+
7880

7981
self.layernorm = keras.layers.LayerNormalization(axis=-1)
8082

0 commit comments

Comments
 (0)