1 change: 1 addition & 0 deletions bayesflow/wrappers/__init__.py
@@ -0,0 +1 @@
from .mamba import Mamba
1 change: 1 addition & 0 deletions bayesflow/wrappers/mamba/__init__.py
@@ -0,0 +1 @@
from .mamba import Mamba
122 changes: 122 additions & 0 deletions bayesflow/wrappers/mamba/mamba.py
@@ -0,0 +1,122 @@
from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .mamba_block import MambaBlock


@serializable("bayesflow.wrappers")
class Mamba(SummaryNetwork):
"""
Wraps a sequence of Mamba modules using the simple Mamba module from:
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

Copyright (c) 2023, Tri Dao, Albert Gu.

Example usage in a BayesFlow workflow as a summary network:

`summary_net = bayesflow.wrappers.Mamba(summary_dim=32)`
"""

def __init__(
self,
summary_dim: int = 16,
feature_dims: Sequence[int] = (64, 64),
state_dims: Sequence[int] = (64, 64),
conv_dims: Sequence[int] = (64, 64),
expand_dims: Sequence[int] = (2, 2),
bidirectional: bool = True,
dt_min: float = 0.001,
dt_max: float = 0.1,
dropout: float = 0.05,
device: str = "cuda",
**kwargs,
):
"""
A time-series summarization network using Mamba-based State Space Models (SSMs). This model processes
sequential input data through a stack of Mamba SSM layers (one layer per entry in the dimension tuples),
followed by global average pooling, dropout, and a dense layer that extracts the summary statistics.

Parameters
----------
summary_dim : int, optional
The output dimensionality of the summary statistics layer (default is 16).
feature_dims : Sequence[int], optional
The feature dimension for each Mamba block (default is (64, 64)).
state_dims : Sequence[int], optional
The dimensionality of the internal state in each Mamba block (default is (64, 64)).
conv_dims : Sequence[int], optional
The dimensionality of the convolutional layer in each Mamba block (default is (64, 64)).
expand_dims : Sequence[int], optional
The expansion factors for the hidden state in each Mamba block (default is (2, 2)).
bidirectional : bool, optional
Whether to process the input sequence in both directions and concatenate the resulting features (default is True).
dt_min : float, optional
Minimum dynamic state evolution over time (default is 0.001).
dt_max : float, optional
Maximum dynamic state evolution over time (default is 0.1).
dropout : float, optional
Dropout rate applied to the pooled summary vector before the summary layer (default is 0.05).
device : str, optional
The computing device. Currently, only "cuda" is supported (default is "cuda").
**kwargs :
Additional keyword arguments passed to the `SummaryNetwork` parent class.
"""

super().__init__(**keras_kwargs(kwargs))

if device != "cuda":
raise NotImplementedError("MambaSSM only supports cuda as `device`.")

self.mamba_blocks = []
for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
mamba = MambaBlock(
state_dim=state_dim, conv_dim=conv_dim, feature_dim=feature_dim, expand=expand,
bidirectional=bidirectional, dt_min=dt_min, dt_max=dt_max, device=device,
)
self.mamba_blocks.append(mamba)

self.pooling_layer = keras.layers.GlobalAveragePooling1D()
self.dropout = keras.layers.Dropout(dropout)
self.summary_stats = keras.layers.Dense(summary_dim)

def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
"""
Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics.

Parameters
----------
time_series : Tensor
Input tensor representing the time series data, typically of shape
(batch_size, sequence_length, feature_dim).
training : bool, optional
Whether the model is in training mode (default is True). Affects behavior of
layers like dropout.
**kwargs : dict
Additional keyword arguments (not used in this method).

Returns
-------
Tensor
Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
"""

summary = time_series
for mamba_block in self.mamba_blocks:
summary = mamba_block(summary, training=training)

summary = self.pooling_layer(summary)
summary = self.dropout(summary, training=training)
summary = self.summary_stats(summary)

return summary

@sanitize_input_shape
def build(self, input_shape):
super().build(input_shape)
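# build all sublayers by tracing a forward pass on a zero-valued tensor of the given shape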
self.call(keras.ops.zeros(input_shape))
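For orientation, here is a minimal usage sketch of the wrapper as a summary network inside a BayesFlow workflow. It assumes the torch backend, a CUDA device, and an installed mamba-ssm package; the simulator, inference network, and training call mirror the test setup further below and are illustrative rather than prescriptive.

import bayesflow as bf

# Mamba summary network: one Mamba block per entry in the dimension tuples
summary_net = bf.wrappers.Mamba(
    summary_dim=32,          # dimensionality of the learned summary statistics
    feature_dims=(64, 64),
    state_dims=(64, 64),
    conv_dims=(64, 64),
    bidirectional=True,      # process the sequence forward and backward
)

# Pair the summary network with an inference network in a standard workflow
workflow = bf.BasicWorkflow(
    inference_network=bf.networks.CouplingFlow(depth=2),
    summary_network=summary_net,
    inference_variables=["parameters"],
    summary_variables=["observables"],
    simulator=bf.simulators.SIR(),
)

history = workflow.fit_online(epochs=2, batch_size=8, num_batches_per_epoch=2)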
114 changes: 114 additions & 0 deletions bayesflow/wrappers/mamba/mamba_block.py
@@ -0,0 +1,114 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import logging, keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

try:
from mamba_ssm import Mamba
except ImportError:
logging.error("Mamba class is not available. Please install the mamba-ssm library via `pip install mamba-ssm`.")


@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
"""
Wraps the original Mamba module, with added support for bidirectional processing, from:
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

Copyright (c) 2023, Tri Dao, Albert Gu.
"""

def __init__(
self,
state_dim: int,
conv_dim: int,
feature_dim: int = 16,
expand: int = 2,
bidirectional: bool = True,
dt_min: float = 0.001,
dt_max: float = 0.1,
device: str = "cuda",
**kwargs,
):
"""
A Keras layer implementing a Mamba-based sequence processing block.

This layer applies a Mamba model for sequence modeling, preceded by a
convolutional projection and followed by layer normalization.

Parameters
----------
state_dim : int
The dimension of the state space in the Mamba model.
conv_dim : int
The dimension of the convolutional layer used in Mamba.
feature_dim : int, optional
The feature dimension for input projection and Mamba processing (default is 16).
expand : int, optional
Expansion factor for Mamba's internal dimension (default is 2).
bidirectional : bool, optional
Whether to process the sequence both forward and backward and concatenate the outputs along the feature axis (default is True).
dt_min : float, optional
Minimum delta time for Mamba (default is 0.001).
dt_max : float, optional
Maximum delta time for Mamba (default is 0.1).
device : str, optional
The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
**kwargs :
Additional keyword arguments passed to the `keras.layers.Layer` initializer.
"""

super().__init__(**keras_kwargs(kwargs))

if keras.backend.backend() != "torch":
raise RuntimeError("Mamba is only available using torch backend.")

self.bidirectional = bidirectional

self.mamba = Mamba(
d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
).to(device)

self.input_projector = keras.layers.Conv1D(
feature_dim,
kernel_size=1,
strides=1,
)
self.layer_norm = keras.layers.LayerNormalization()

def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
"""
Applies the Mamba layer to the input tensor `x`, optionally in a bidirectional manner.

Parameters
----------
x : Tensor
Input tensor of shape `(batch_size, sequence_length, input_dim)`.
training : bool, optional
Whether the layer should behave in training mode (e.g., applying dropout). Default is False.
**kwargs : dict
Additional keyword arguments passed to the internal `_call` method.

Returns
-------
Tensor
Output tensor of shape `(batch_size, sequence_length, feature_dim)` if unidirectional,
or `(batch_size, sequence_length, 2 * feature_dim)` if bidirectional.
"""

out_forward = self._call(x, training=training, **kwargs)
if self.bidirectional:
out_backward = self._call(keras.ops.flip(x, axis=-2), training=training, **kwargs)
return keras.ops.concatenate((out_forward, out_backward), axis=-1)
return out_forward

def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
x = self.input_projector(x)
h = self.mamba(x)
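# residual connection around the Mamba module, followed by layer normalization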
out = self.layer_norm(h + x, training=training, **kwargs)
return out

@sanitize_input_shape
def build(self, input_shape):
super().build(input_shape)
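# build all sublayers by tracing a forward pass on a zero-valued tensor of the given shape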
self.call(keras.ops.zeros(input_shape))
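To illustrate the bidirectional behavior documented in `MambaBlock.call`, the following self-contained sketch reproduces the flip-and-concatenate pattern with a plain `Dense` layer as a hypothetical stand-in for the Mamba module, so it runs on any backend; the last dimension doubles exactly as described in the docstring.

import keras

batch, seq_len, feature_dim = 2, 80, 16
x = keras.random.normal(shape=(batch, seq_len, feature_dim))

stand_in = keras.layers.Dense(feature_dim)  # hypothetical stand-in for the wrapped Mamba module

out_forward = stand_in(x)                                    # pass over the sequence in its original order
out_backward = stand_in(keras.ops.flip(x, axis=-2))          # same layer on the time-reversed sequence
out = keras.ops.concatenate((out_forward, out_backward), axis=-1)

print(out.shape)  # (2, 80, 32): the feature dimension doubles when bidirectional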
Empty file added tests/test_wrappers/__init__.py
22 changes: 22 additions & 0 deletions tests/test_wrappers/conftest.py
@@ -0,0 +1,22 @@
import pytest


@pytest.fixture()
def inference_network():
from bayesflow.networks import CouplingFlow

return CouplingFlow(depth=2)


@pytest.fixture()
def random_time_series():
import keras

return keras.random.normal(shape=(2, 80, 2))


@pytest.fixture()
def mamba_summary_network():
from bayesflow.wrappers.mamba import Mamba

return Mamba(summary_dim=4, feature_dims=(2, 2), state_dims=(4, 4), conv_dims=(8, 8))
25 changes: 25 additions & 0 deletions tests/test_wrappers/test_mamba.py
@@ -0,0 +1,25 @@
import pytest

import bayesflow as bf


@pytest.mark.torch
def test_mamba_summary(random_time_series, mamba_summary_network):
out = mamba_summary_network(random_time_series)
# Batch size 2, summary dim 4
assert out.shape == (2, 4)


@pytest.mark.torch
def test_mamba_trains(random_time_series, inference_network, mamba_summary_network):
workflow = bf.BasicWorkflow(
inference_network=inference_network,
summary_network=mamba_summary_network,
inference_variables=["parameters"],
summary_variables=["observables"],
simulator=bf.simulators.SIR(),
)

history = workflow.fit_online(epochs=2, batch_size=8, num_batches_per_epoch=2)
assert "loss" in list(history.history.keys())
assert len(history.history["loss"]) == 2