
Commit c1f6bb9

Finalize wrapper

Parent: 4924791


bayesflow/wrappers/mamba.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -16,14 +16,22 @@
 
 @serializable("bayesflow.wrappers")
 class MambaBlock(keras.Layer):
+    """
+    Wraps the original Mamba module from, with added functionality for bidirectional processing:
+    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
+
+    Copyright (c) 2023, Tri Dao, Albert Gu.
+    """
+
     def __init__(
         self,
         state_dim: int,
         conv_dim: int,
         feature_dim: int = 16,
-        expand: int = 1,
-        dt_min=0.001,
-        dt_max=0.1,
+        expand: int = 2,
+        bidirectional: bool = True,
+        dt_min: float = 0.001,
+        dt_max: float = 0.1,
         device: str = "cuda",
         **kwargs,
     ):
```
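The constructor now defaults to an expansion factor of 2 and exposes a `bidirectional` flag. A minimal construction sketch for the updated signature, assuming the torch backend is active and the mamba-ssm package is installed (the argument values below are illustrative, not recommended settings):

```python
# Illustrative values only; MambaBlock as defined in bayesflow/wrappers/mamba.py.
block = MambaBlock(
    state_dim=16,        # d_state of the wrapped Mamba module
    conv_dim=4,          # d_conv of the wrapped Mamba module
    feature_dim=16,      # d_model of the wrapped Mamba module
    expand=2,            # new default (was 1)
    bidirectional=True,  # new flag: also process the time-reversed sequence
    dt_min=0.001,
    dt_max=0.1,
    device="cuda",
)
```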
```diff
@@ -58,6 +66,8 @@ def __init__(
         if keras.backend.backend() != "torch":
             raise EnvironmentError("Mamba is only available using torch backend.")
 
+        self.bidirectional = bidirectional
+
         self.mamba = Mamba(
             d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
         ).to(device)
@@ -70,6 +80,13 @@ def __init__(
         self.layer_norm = keras.layers.LayerNormalization()
 
     def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
+        out_forward = self._call(x, training=training, **kwargs)
+        if self.bidirectional:
+            out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs)
+            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
+        return out_forward
+
+    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
         x = self.input_projector(x)
         h = self.mamba(x)
         out = self.layer_norm(h + x, training=training, **kwargs)
```
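When `bidirectional=True`, `call` runs the same block over the time-reversed sequence and concatenates both results along the feature axis, so the per-block output width doubles. A backend-agnostic sketch of just the flip-and-concatenate step, with stand-in arrays in place of the `_call` outputs:

```python
import numpy as np
import keras

# Stand-in for a batch of sequences: (batch, time, features).
x = np.random.randn(8, 100, 16).astype("float32")

out_forward = x                           # placeholder for self._call(x, ...)
out_backward = keras.ops.flip(x, axis=1)  # placeholder for self._call(flipped x, ...)

combined = keras.ops.concatenate((out_forward, out_backward), axis=-1)
print(keras.ops.shape(combined))          # (8, 100, 32): the feature axis doubles
```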
```diff
@@ -84,7 +101,7 @@ def build(self, input_shape):
 @serializable("bayesflow.wrappers")
 class MambaSSM(SummaryNetwork):
     """
-    Wraps the original Mamba module from:
+    Wraps a sequence of Mamba modules using the simple Mamba module from:
     https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
 
     Copyright (c) 2023, Tri Dao, Albert Gu.
@@ -94,9 +111,10 @@ def __init__(
         self,
         summary_dim: int = 16,
         feature_dims: Sequence[int] = (64, 64),
-        state_dims: Sequence[int] = (128, 128),
+        state_dims: Sequence[int] = (64, 64),
         conv_dims: Sequence[int] = (64, 64),
         expand_dims: Sequence[int] = (2, 2),
+        bidirectional: bool = True,
         dt_min: float = 0.001,
         dt_max: float = 0.1,
         dropout: float = 0.05,
@@ -143,7 +161,8 @@ def __init__(
 
         self.mamba_blocks = []
         for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
-            self.mamba_blocks.append(MambaBlock(feature_dim, state_dim, conv_dim, expand, dt_min, dt_max, device))
+            mamba = MambaBlock(feature_dim, state_dim, conv_dim, expand, bidirectional, dt_min, dt_max, device)
+            self.mamba_blocks.append(mamba)
 
         self.pooling_layer = keras.layers.GlobalAveragePooling1D()
         self.dropout = keras.layers.Dropout(dropout)
```
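A hedged end-to-end usage sketch for the updated MambaSSM follows; the import path, input shapes, and expected output shape are assumptions, and running it requires the torch backend, a CUDA device, and the mamba-ssm package:

```python
import numpy as np
from bayesflow.wrappers import MambaSSM  # assumed import path for this wrapper

# Two stacked Mamba blocks, each processing the sequence in both directions.
summary_net = MambaSSM(
    summary_dim=16,
    feature_dims=(64, 64),
    state_dims=(64, 64),   # new default in this commit (was (128, 128))
    conv_dims=(64, 64),
    expand_dims=(2, 2),
    bidirectional=True,
    dropout=0.05,
)

# Illustrative batch: 32 time series, 200 steps, 3 observed variables.
x = np.random.randn(32, 200, 3).astype("float32")
summaries = summary_net(x)  # expected shape: (32, 16), i.e. (batch, summary_dim)
```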
