
Commit 62cf5b5

Enable Mamba as Summary Nets (#398)
* Add first wrapper: Mamba
* Finalize wrapper
* Refactor mamba wrapper
* Fix import
1 parent 557a08d commit 62cf5b5
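
For orientation, here is a minimal sketch of the new wrapper inside a BayesFlow workflow. It mirrors the docstring example (`bayesflow.wrappers.Mamba(summary_dim=32)`) and the new tests below; it assumes the torch backend, a CUDA device, and `pip install mamba-ssm`, and the CouplingFlow/SIR choices are illustrative, taken from the added tests:

import bayesflow as bf

# Sketch only: requires the torch backend, a CUDA device, and mamba-ssm.
summary_net = bf.wrappers.Mamba(summary_dim=32)

workflow = bf.BasicWorkflow(
    inference_network=bf.networks.CouplingFlow(depth=2),
    summary_network=summary_net,
    inference_variables=["parameters"],
    summary_variables=["observables"],
    simulator=bf.simulators.SIR(),
)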

File tree

bayesflow/wrappers/__init__.py
bayesflow/wrappers/mamba/__init__.py
bayesflow/wrappers/mamba/mamba.py
bayesflow/wrappers/mamba/mamba_block.py
tests/test_wrappers/__init__.py
tests/test_wrappers/conftest.py
tests/test_wrappers/test_mamba.py

7 files changed: +285 −0 lines changed

bayesflow/wrappers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .mamba import Mamba

bayesflow/wrappers/mamba/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .mamba import Mamba

bayesflow/wrappers/mamba/mamba.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .mamba_block import MambaBlock


@serializable("bayesflow.wrappers")
class Mamba(SummaryNetwork):
    """
    Wraps a sequence of Mamba modules using the simple Mamba module from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.

    Example usage in a BayesFlow workflow as a summary network:

    `summary_net = bayesflow.wrappers.Mamba(summary_dim=32)`
    """

    def __init__(
        self,
        summary_dim: int = 16,
        feature_dims: Sequence[int] = (64, 64),
        state_dims: Sequence[int] = (64, 64),
        conv_dims: Sequence[int] = (64, 64),
        expand_dims: Sequence[int] = (2, 2),
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dropout: float = 0.05,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A time-series summarization network using Mamba-based State Space Models (SSM). This model processes
        sequential input data through a sequence of Mamba SSM layers (one per entry in the dimension tuples),
        followed by global average pooling, dropout, and a dense layer that extracts summary statistics.

        Parameters
        ----------
        summary_dim : int, optional
            The output dimensionality of the summary statistics layer (default is 16).
        feature_dims : Sequence[int], optional
            The feature dimension for each Mamba block (default is (64, 64)).
        state_dims : Sequence[int], optional
            The dimensionality of the internal state in each Mamba block (default is (64, 64)).
        conv_dims : Sequence[int], optional
            The dimensionality of the convolutional layer in each Mamba block (default is (64, 64)).
        expand_dims : Sequence[int], optional
            The expansion factors for the hidden state in each Mamba block (default is (2, 2)).
        bidirectional : bool, optional
            Whether each block processes the sequence in both directions and concatenates
            the results (default is True).
        dt_min : float, optional
            Minimum dynamic state evolution over time (default is 0.001).
        dt_max : float, optional
            Maximum dynamic state evolution over time (default is 0.1).
        dropout : float, optional
            Dropout rate applied to the pooled summary vector before the summary layer (default is 0.05).
        device : str, optional
            The computing device. Currently, only "cuda" is supported (default is "cuda").
        **kwargs :
            Additional keyword arguments passed to the `SummaryNetwork` parent class.
        """

        super().__init__(**keras_kwargs(kwargs))

        if device != "cuda":
            raise NotImplementedError("MambaSSM only supports cuda as `device`.")

        self.mamba_blocks = []
        for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
            # Pass arguments by keyword to match the MambaBlock signature
            # (state_dim, conv_dim, feature_dim, expand, ...).
            mamba = MambaBlock(
                state_dim=state_dim,
                conv_dim=conv_dim,
                feature_dim=feature_dim,
                expand=expand,
                bidirectional=bidirectional,
                dt_min=dt_min,
                dt_max=dt_max,
                device=device,
            )
            self.mamba_blocks.append(mamba)

        self.pooling_layer = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
        """
        Apply a sequence of Mamba blocks, followed by pooling, dropout, and a summary statistics layer.

        Parameters
        ----------
        time_series : Tensor
            Input tensor representing the time series data, typically of shape
            (batch_size, sequence_length, feature_dim).
        training : bool, optional
            Whether the model is in training mode (default is True). Affects the behavior of
            layers such as dropout.
        **kwargs : dict
            Additional keyword arguments (not used in this method).

        Returns
        -------
        Tensor
            Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
        """

        summary = time_series
        for mamba_block in self.mamba_blocks:
            summary = mamba_block(summary, training=training)

        summary = self.pooling_layer(summary)
        summary = self.dropout(summary, training=training)
        summary = self.summary_stats(summary)

        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
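
A quick shape check for the wrapper, under the same assumptions (torch backend, CUDA device, mamba-ssm installed). The block dimensions are taken from the test fixture added below:

import keras
from bayesflow.wrappers import Mamba

# Sketch mirroring the new test: batch size 2, summary_dim 4.
net = Mamba(summary_dim=4, feature_dims=(2, 2), state_dims=(4, 4), conv_dims=(8, 8))
x = keras.random.normal(shape=(2, 80, 2))  # (batch, sequence_length, features)
out = net(x)
print(out.shape)  # expected: (2, 4)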
bayesflow/wrappers/mamba/mamba_block.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import logging, keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

try:
    from mamba_ssm import Mamba
except ImportError:
    logging.error("Mamba class is not available. Please install the mamba-ssm library via `pip install mamba-ssm`.")


@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
    """
    Wraps the original Mamba module, with added functionality for bidirectional processing:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        state_dim: int,
        conv_dim: int,
        feature_dim: int = 16,
        expand: int = 2,
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A Keras layer implementing a Mamba-based sequence processing block.

        This layer applies a Mamba model for sequence modeling, preceded by a
        convolutional projection and followed by layer normalization.

        Parameters
        ----------
        state_dim : int
            The dimension of the state space in the Mamba model.
        conv_dim : int
            The dimension of the convolutional layer used in Mamba.
        feature_dim : int, optional
            The feature dimension for input projection and Mamba processing (default is 16).
        expand : int, optional
            Expansion factor for Mamba's internal dimension (default is 2).
        bidirectional : bool, optional
            Whether to process the sequence in both directions and concatenate the results (default is True).
        dt_min : float, optional
            Minimum delta time for Mamba (default is 0.001).
        dt_max : float, optional
            Maximum delta time for Mamba (default is 0.1).
        device : str, optional
            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
        **kwargs :
            Additional keyword arguments passed to the `keras.layers.Layer` initializer.
        """

        super().__init__(**keras_kwargs(kwargs))

        if keras.backend.backend() != "torch":
            raise RuntimeError("Mamba is only available using the torch backend.")

        self.bidirectional = bidirectional

        self.mamba = Mamba(
            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
        ).to(device)

        self.input_projector = keras.layers.Conv1D(
            feature_dim,
            kernel_size=1,
            strides=1,
        )
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Applies the Mamba layer to the input tensor `x`, optionally in a bidirectional manner.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape `(batch_size, sequence_length, input_dim)`.
        training : bool, optional
            Whether the layer should behave in training mode (e.g., applying dropout). Default is False.
        **kwargs : dict
            Additional keyword arguments passed to the internal `_call` method.

        Returns
        -------
        Tensor
            Output tensor of shape `(batch_size, sequence_length, feature_dim)` if unidirectional,
            or `(batch_size, sequence_length, 2 * feature_dim)` if bidirectional.
        """

        out_forward = self._call(x, training=training, **kwargs)
        if self.bidirectional:
            # Run the block again on the time-reversed sequence and concatenate
            # along the feature axis.
            out_backward = self._call(keras.ops.flip(x, axis=-2), training=training, **kwargs)
            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
        return out_forward

    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        # Project inputs to the Mamba feature dimension, apply Mamba, then
        # layer-normalize with a residual connection.
        x = self.input_projector(x)
        h = self.mamba(x)
        out = self.layer_norm(h + x, training=training, **kwargs)
        return out

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
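
To illustrate the bidirectional concatenation in `call`: with `bidirectional=True`, the output feature dimension doubles. A sketch (the module path is assumed from the file layout in this diff; torch backend, CUDA device, and mamba-ssm required):

import keras
from bayesflow.wrappers.mamba.mamba_block import MambaBlock

# Forward and time-reversed passes are concatenated along the feature axis,
# so the last dimension is 2 * feature_dim.
block = MambaBlock(state_dim=4, conv_dim=8, feature_dim=16, bidirectional=True)
x = keras.random.normal(shape=(2, 80, 2))
print(block(x).shape)  # expected: (2, 80, 32)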

tests/test_wrappers/__init__.py

Whitespace-only changes.

tests/test_wrappers/conftest.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import pytest


@pytest.fixture()
def inference_network():
    from bayesflow.networks import CouplingFlow

    return CouplingFlow(depth=2)


@pytest.fixture()
def random_time_series():
    import keras

    return keras.random.normal(shape=(2, 80, 2))


@pytest.fixture()
def mamba_summary_network():
    from bayesflow.wrappers.mamba import Mamba

    return Mamba(summary_dim=4, feature_dims=(2, 2), state_dims=(4, 4), conv_dims=(8, 8))

tests/test_wrappers/test_mamba.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
import pytest

import bayesflow as bf


@pytest.mark.torch
def test_mamba_summary(random_time_series, mamba_summary_network):
    out = mamba_summary_network(random_time_series)
    # Batch size 2, summary dim 4
    assert out.shape == (2, 4)


@pytest.mark.torch
def test_mamba_trains(random_time_series, inference_network, mamba_summary_network):
    workflow = bf.BasicWorkflow(
        inference_network=inference_network,
        summary_network=mamba_summary_network,
        inference_variables=["parameters"],
        summary_variables=["observables"],
        simulator=bf.simulators.SIR(),
    )

    history = workflow.fit_online(epochs=2, batch_size=8, num_batches_per_epoch=2)
    assert "loss" in list(history.history.keys())
    assert len(history.history["loss"]) == 2
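
Both tests carry `pytest.mark.torch`; assuming that marker is registered in the project's pytest configuration, one plausible way to run just the new tests is:

pytest tests/test_wrappers -m torch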
