
Commit 43f561a

Refactor mamba wrapper
1 parent c1f6bb9 commit 43f561a

4 files changed: +124 -92 lines changed


bayesflow/wrappers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .mamba import MambaSSM
+from .mamba import Mamba

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from mamba import Mamba
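
For downstream code, the only user-visible change in these two files is the public name of the wrapper. A minimal before/after sketch of the import, assuming no compatibility alias is kept for the old name:

# Before this commit:
# from bayesflow.wrappers import MambaSSM

# After this commit:
from bayesflow.wrappers import Mamba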

bayesflow/wrappers/mamba.py renamed to bayesflow/wrappers/mamba/mamba.py

Lines changed: 8 additions & 91 deletions
@@ -5,106 +5,23 @@

 from bayesflow.networks.summary_network import SummaryNetwork
 from bayesflow.types import Tensor
-from bayesflow.utils import logging
+from bayesflow.utils import keras_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape

-try:
-    from mamba_ssm import Mamba
-except ImportError:
-    logging.error("Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`.")
+from .mamba_block import MambaBlock


 @serializable("bayesflow.wrappers")
-class MambaBlock(keras.Layer):
+class Mamba(SummaryNetwork):
     """
-    Wraps the original Mamba module from, with added functionality for bidirectional processing:
+    Wraps a sequence of Mamba modules using the simple Mamba module from:
     https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

     Copyright (c) 2023, Tri Dao, Albert Gu.
-    """
-
-    def __init__(
-        self,
-        state_dim: int,
-        conv_dim: int,
-        feature_dim: int = 16,
-        expand: int = 2,
-        bidirectional: bool = True,
-        dt_min: float = 0.001,
-        dt_max: float = 0.1,
-        device: str = "cuda",
-        **kwargs,
-    ):
-        """
-        A Keras layer implementing a Mamba-based sequence processing block.
-
-        This layer applies a Mamba model for sequence modeling, preceded by a
-        convolutional projection and followed by layer normalization.
-
-        Parameters
-        ----------
-        state_dim : int
-            The dimension of the state space in the Mamba model.
-        conv_dim : int
-            The dimension of the convolutional layer used in Mamba.
-        feature_dim : int, optional
-            The feature dimension for input projection and Mamba processing (default is 16).
-        expand : int, optional
-            Expansion factor for Mamba's internal dimension (default is 1).
-        dt_min : float, optional
-            Minimum delta time for Mamba (default is 0.001).
-        dt_max : float, optional
-            Maximum delta time for Mamba (default is 0.1).
-        device : str, optional
-            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
-        **kwargs : dict
-            Additional keyword arguments passed to the `keras.layers.Layer` initializer.
-        """
-
-        super().__init__(**kwargs)
-
-        if keras.backend.backend() != "torch":
-            raise EnvironmentError("Mamba is only available using torch backend.")
-
-        self.bidirectional = bidirectional
-
-        self.mamba = Mamba(
-            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
-        ).to(device)
-
-        self.input_projector = keras.layers.Conv1D(
-            feature_dim,
-            kernel_size=1,
-            strides=1,
-        )
-        self.layer_norm = keras.layers.LayerNormalization()
-
-    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
-        out_forward = self._call(x, training=training, **kwargs)
-        if self.bidirectional:
-            out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs)
-            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
-        return out_forward
-
-    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
-        x = self.input_projector(x)
-        h = self.mamba(x)
-        out = self.layer_norm(h + x, training=training, **kwargs)
-        return out
-
-    @sanitize_input_shape
-    def build(self, input_shape):
-        super().build(input_shape)
-        self.call(keras.ops.zeros(input_shape))

+    Example usage in a BayesFlow workflow as a summary network:

-@serializable("bayesflow.wrappers")
-class MambaSSM(SummaryNetwork):
-    """
-    Wraps a sequence of Mamba modules using the simple Mamba module from:
-    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
-
-    Copyright (c) 2023, Tri Dao, Albert Gu.
+    `summary_net = bayesflow.wrappers.Mamba(summary_dim=32)`
     """

     def __init__(

@@ -150,11 +67,11 @@ def __init__(
             Dropout probability; dropout is applied to the pooled summary vector.
         device : str, optional
             The computing device. Currently, only "cuda" is supported (default is "cuda").
-        **kwargs : dict
+        **kwargs :
            Additional keyword arguments passed to the `SummaryNetwork` parent class.
         """

-        super().__init__(**kwargs)
+        super().__init__(**keras_kwargs(kwargs))

         if device != "cuda":
             raise NotImplementedError("MambaSSM only supports cuda as `device`.")
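
The diff above splits `MambaBlock` out of this module (it now lives in its own file, imported via `from .mamba_block import MambaBlock`) and renames the summary network from `MambaSSM` to `Mamba`. A minimal usage sketch expanding on the docstring example; the input shape and the expected output shape are illustrative assumptions, and running it requires the torch backend, a CUDA device, and the `mamba-ssm` package, as enforced by the wrapper:

import keras
from bayesflow.wrappers import Mamba

# Construct the wrapper as in the docstring example added by this commit
# (written there as bayesflow.wrappers.Mamba(summary_dim=32)).
summary_net = Mamba(summary_dim=32)

# Illustrative batch of 8 time series, each 128 steps long with 2 observed variables.
x = keras.ops.zeros((8, 128, 2))

# As a SummaryNetwork, the wrapper pools each sequence into a fixed-size summary vector.
summaries = summary_net(x)
print(keras.ops.shape(summaries))  # expected: (8, 32), one summary_dim-sized vector per series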
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import logging, keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

try:
    from mamba_ssm import Mamba
except ImportError:
    logging.error("Mamba class is not available. Please install the mamba-ssm library via `pip install mamba-ssm`.")


@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
    """
    Wraps the original Mamba module, with added functionality for bidirectional processing, from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        state_dim: int,
        conv_dim: int,
        feature_dim: int = 16,
        expand: int = 2,
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A Keras layer implementing a Mamba-based sequence processing block.

        This layer applies a Mamba model for sequence modeling, preceded by a
        convolutional projection and followed by layer normalization.

        Parameters
        ----------
        state_dim : int
            The dimension of the state space in the Mamba model.
        conv_dim : int
            The dimension of the convolutional layer used in Mamba.
        feature_dim : int, optional
            The feature dimension for input projection and Mamba processing (default is 16).
        expand : int, optional
            Expansion factor for Mamba's internal dimension (default is 2).
        dt_min : float, optional
            Minimum delta time for Mamba (default is 0.001).
        dt_max : float, optional
            Maximum delta time for Mamba (default is 0.1).
        device : str, optional
            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
        **kwargs :
            Additional keyword arguments passed to the `keras.layers.Layer` initializer.
        """

        super().__init__(**keras_kwargs(kwargs))

        if keras.backend.backend() != "torch":
            raise RuntimeError("Mamba is only available using torch backend.")

        self.bidirectional = bidirectional

        self.mamba = Mamba(
            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
        ).to(device)

        self.input_projector = keras.layers.Conv1D(
            feature_dim,
            kernel_size=1,
            strides=1,
        )
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Applies the Mamba layer to the input tensor `x`, optionally in a bidirectional manner.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape `(batch_size, sequence_length, input_dim)`.
        training : bool, optional
            Whether the layer should behave in training mode (e.g., applying dropout). Default is False.
        **kwargs : dict
            Additional keyword arguments passed to the internal `_call` method.

        Returns
        -------
        Tensor
            Output tensor of shape `(batch_size, sequence_length, feature_dim)` if unidirectional,
            or `(batch_size, sequence_length, 2 * feature_dim)` if bidirectional.
        """

        out_forward = self._call(x, training=training, **kwargs)
        if self.bidirectional:
            out_backward = self._call(keras.ops.flip(x, axis=-2), training=training, **kwargs)
            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
        return out_forward

    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        x = self.input_projector(x)
        h = self.mamba(x)
        out = self.layer_norm(h + x, training=training, **kwargs)
        return out

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
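
The `MambaBlock` above can also be exercised on its own. A hedged sketch under illustrative settings: `state_dim=16`, `conv_dim=4`, and the dummy input shape are arbitrary choices rather than values from this commit, the import path is inferred from the relative import `from .mamba_block import MambaBlock` in mamba.py, and running it requires the torch backend, tensors on a CUDA device, and the `mamba-ssm` package:

import keras

from bayesflow.wrappers.mamba.mamba_block import MambaBlock  # path inferred from the relative import above

# Illustrative hyperparameters; state_dim and conv_dim have no defaults and must be provided.
block = MambaBlock(state_dim=16, conv_dim=4, feature_dim=16, bidirectional=True)

# Dummy batch: 8 sequences of length 128 with 2 features each.
x = keras.ops.zeros((8, 128, 2))
out = block(x)

# With bidirectional=True, the forward and reversed passes are concatenated on the last axis,
# so the output is (batch, sequence_length, 2 * feature_dim) per the class docstring.
print(keras.ops.shape(out))  # expected: (8, 128, 32)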
