
Commit 4924791

Add first wrapper: Mamba

1 parent 1284694 commit 4924791

File tree

2 files changed: +187 −0 lines changed


bayesflow/wrappers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .mamba import MambaSSM

bayesflow/wrappers/mamba.py

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import logging
from bayesflow.utils.decorators import sanitize_input_shape

try:
    from mamba_ssm import Mamba
except ImportError:
    logging.error("Mamba class is not available. Please install the mamba-ssm library via `pip install mamba-ssm`.")


@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
    def __init__(
        self,
        state_dim: int,
        conv_dim: int,
        feature_dim: int = 16,
        expand: int = 1,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A Keras layer implementing a Mamba-based sequence processing block.

        This layer applies a Mamba model for sequence modeling, preceded by a
        convolutional projection and followed by layer normalization.

        Parameters
        ----------
        state_dim : int
            The dimension of the state space in the Mamba model.
        conv_dim : int
            The dimension of the convolutional layer used in Mamba.
        feature_dim : int, optional
            The feature dimension for input projection and Mamba processing (default is 16).
        expand : int, optional
            Expansion factor for Mamba's internal dimension (default is 1).
        dt_min : float, optional
            Minimum delta time for Mamba (default is 0.001).
        dt_max : float, optional
            Maximum delta time for Mamba (default is 0.1).
        device : str, optional
            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `keras.layers.Layer` initializer.
        """
        super().__init__(**kwargs)

        if keras.backend.backend() != "torch":
            raise EnvironmentError("Mamba is only available using the torch backend.")

        self.mamba = Mamba(
            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
        ).to(device)

        self.input_projector = keras.layers.Conv1D(
            feature_dim,
            kernel_size=1,
            strides=1,
        )
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        # Project inputs to feature_dim, run Mamba, then apply a residual
        # connection followed by layer normalization.
        x = self.input_projector(x)
        h = self.mamba(x)
        out = self.layer_norm(h + x, training=training, **kwargs)
        return out

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
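
For reference, a minimal sketch of using MambaBlock on its own. This is an illustration, not part of the commit; it assumes the torch backend is selected before keras is imported, a CUDA device is available, and mamba-ssm is installed. All shapes are illustrative.

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import torch
torch.set_default_device("cuda")  # MambaBlock moves the Mamba module to CUDA

import keras
from bayesflow.wrappers.mamba import MambaBlock

# One block: 1x1 Conv1D projection to feature_dim, Mamba, residual + layer norm.
block = MambaBlock(state_dim=128, conv_dim=64, feature_dim=16)
x = keras.ops.zeros((8, 100, 3))  # (batch_size, sequence_length, channels)
y = block(x)                      # (8, 100, 16): channels projected to feature_dim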


@serializable("bayesflow.wrappers")
class MambaSSM(SummaryNetwork):
    """
    Wraps the original Mamba module from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        feature_dims: Sequence[int] = (64, 64),
        state_dims: Sequence[int] = (128, 128),
        conv_dims: Sequence[int] = (64, 64),
        expand_dims: Sequence[int] = (2, 2),
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dropout: float = 0.05,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A time-series summarization network using Mamba-based State Space Models (SSM). This model processes
        sequential input data through a stack of Mamba SSM blocks (one per entry in the dimension tuples),
        followed by global average pooling, dropout, and a dense layer that extracts summary statistics.

        Parameters
        ----------
        summary_dim : int, optional
            The output dimensionality of the summary statistics layer (default is 16).
        feature_dims : Sequence[int], optional
            The feature dimension for each Mamba block (default is (64, 64)).
        state_dims : Sequence[int], optional
            The dimensionality of the internal state in each Mamba block (default is (128, 128)).
        conv_dims : Sequence[int], optional
            The dimensionality of the convolutional layer in each Mamba block (default is (64, 64)).
        expand_dims : Sequence[int], optional
            The expansion factors for the hidden state in each Mamba block (default is (2, 2)).
        dt_min : float, optional
            Minimum delta time for state evolution in the Mamba blocks (default is 0.001).
        dt_max : float, optional
            Maximum delta time for state evolution in the Mamba blocks (default is 0.1).
        dropout : float, optional
            Dropout probability applied to the pooled summary vector (default is 0.05).
        device : str, optional
            The computing device. Currently, only "cuda" is supported (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `SummaryNetwork` parent class.
        """
        super().__init__(**kwargs)

        if device != "cuda":
            raise NotImplementedError("MambaSSM only supports cuda as `device`.")

        self.mamba_blocks = []
        for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
            # Pass arguments by keyword: the MambaBlock signature is
            # (state_dim, conv_dim, feature_dim, expand, dt_min, dt_max, device).
            self.mamba_blocks.append(
                MambaBlock(
                    state_dim=state_dim,
                    conv_dim=conv_dim,
                    feature_dim=feature_dim,
                    expand=expand,
                    dt_min=dt_min,
                    dt_max=dt_max,
                    device=device,
                )
            )

        self.pooling_layer = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
        """
        Apply a sequence of Mamba blocks, followed by pooling, dropout, and the summary statistics layer.

        Parameters
        ----------
        time_series : Tensor
            Input tensor representing the time series data, typically of shape
            (batch_size, sequence_length, feature_dim).
        training : bool, optional
            Whether the model is in training mode (default is True). Affects the behavior of
            layers like dropout.
        **kwargs : dict
            Additional keyword arguments (not used in this method).

        Returns
        -------
        Tensor
            Output tensor after applying Mamba blocks, pooling, dropout, and the summary statistics layer.
        """
        summary = time_series
        for mamba_block in self.mamba_blocks:
            summary = mamba_block(summary, training=training)

        summary = self.pooling_layer(summary)
        summary = self.dropout(summary, training=training)
        summary = self.summary_stats(summary)

        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
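
A minimal usage sketch for the full summary network, under the same assumptions as above (torch backend selected before importing keras, CUDA device available, mamba-ssm installed; shapes are illustrative):

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import torch
torch.set_default_device("cuda")  # the wrapped Mamba modules live on CUDA

import keras
from bayesflow.wrappers import MambaSSM

# Two stacked Mamba blocks, one per entry in the dimension tuples.
summary_net = MambaSSM(summary_dim=16, feature_dims=(64, 64), state_dims=(128, 128), conv_dims=(64, 64))
time_series = keras.ops.zeros((8, 100, 3))  # (batch_size, time_steps, channels)
summaries = summary_net(time_series)        # (8, 16): one summary vector per series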
